Skip to content

Commit

Permalink
Prepare new ChangeEvents
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Oct 23, 2024
1 parent 2488d81 commit 3450b3a
Show file tree
Hide file tree
Showing 42 changed files with 287 additions and 322 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure().compile_with_config(
tonic_build::configure().compile_protos_with_config(
config,
&[
"proto/core/auth.proto",
Expand Down
28 changes: 15 additions & 13 deletions src/appstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@ use webauthn_rs::prelude::*;

use crate::{
auth::failed_login::FailedLoginMap,
db::{models::wireguard::ChangeEvent, AppEvent, WebHook},
db::models::{
webhook::{AppEvent, WebHook},
wireguard::ChangeEvent,
},
mail::Mail,
server_config,
};

#[derive(Clone)]
pub struct AppState {
pub(crate) struct AppState {
pub pool: PgPool,
tx: UnboundedSender<AppEvent>,
wireguard_tx: Sender<ChangeEvent>,
pub mail_tx: UnboundedSender<Mail>,
pub webauthn: Arc<Webauthn>,
pub user_agent_parser: Arc<UserAgentParser>,
pub failed_logins: Arc<Mutex<FailedLoginMap>>,
// A key for secret cookies.
key: Key,
}

Expand All @@ -44,26 +48,24 @@ impl AppState {
}
}

/// Handle webhook events
/// Handle webhook events. This is ran on a separate asynchronous task.
async fn handle_triggers(pool: PgPool, mut rx: UnboundedReceiver<AppEvent>) {
let reqwest_client = Client::builder().user_agent("reqwest").build().unwrap();
while let Some(msg) = rx.recv().await {
debug!("Webhook triggered. Retrieving webhooks.");
if let Ok(webhooks) = WebHook::all_enabled(&pool, &msg).await {
info!("Found webhooks: {webhooks:#?}");
let (payload, event) = match msg {
AppEvent::UserCreated(user) => (json!(user), "user_created"),
AppEvent::UserModified(user) => (json!(user), "user_modified"),
AppEvent::UserDeleted(username) => {
(json!({ "username": username }), "user_deleted")
}
AppEvent::HWKeyProvision(data) => (json!(data), "user_keys"),
let payload = match msg {
AppEvent::UserCreated(ref user) => json!(user),
AppEvent::UserModified(ref user) => json!(user),
AppEvent::UserDeleted(ref username) => json!({ "username": username }),
AppEvent::HWKeyProvision(ref data) => json!(data),
};
for webhook in webhooks {
match reqwest_client
.post(&webhook.url)
.bearer_auth(&webhook.token)
.header("x-defguard-event", event)
.header("x-defguard-event", msg.name())
.json(&payload)
.send()
.await
Expand All @@ -81,14 +83,14 @@ impl AppState {
}

/// Sends given `ChangeEvent` to be handled by gateway (over gRPC).
pub fn send_change_event(&self, event: ChangeEvent) {
pub(crate) fn send_change_event(&self, event: ChangeEvent) {
if let Err(err) = self.wireguard_tx.send(event) {
error!("Error sending change event {err}");
}
}

/// Sends multiple events to be handled by gateway (over gRPC).
pub fn send_multiple_change_events(&self, events: Vec<ChangeEvent>) {
pub(crate) fn send_multiple_change_events(&self, events: Vec<ChangeEvent>) {
debug!("Sending {} change events", events.len());
for event in events {
self.send_change_event(event);
Expand Down
10 changes: 5 additions & 5 deletions src/auth/failed_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use thiserror::Error;
use utoipa::ToSchema;

// Time window in seconds
const FAILED_LOGIN_WINDOW: i64 = 60;
const FAILED_LOGIN_WINDOW: Duration = Duration::seconds(60);
// Failed login count threshold
const FAILED_LOGIN_COUNT: u32 = 5;
// How long (in seconds) to lock users out after crossing the threshold
const FAILED_LOGIN_TIMEOUT: i64 = 5 * 60;
const FAILED_LOGIN_TIMEOUT: Duration = Duration::seconds(5 * 60);

#[derive(Debug, Error, ToSchema)]
#[error("Too many login attempts")]
Expand Down Expand Up @@ -58,16 +58,16 @@ impl FailedLogin {
// Check if user login attempt should be stopped
fn should_prevent_login(&self) -> bool {
self.attempt_count >= FAILED_LOGIN_COUNT
&& self.time_since_last_attempt() <= Duration::seconds(FAILED_LOGIN_TIMEOUT)
&& self.time_since_last_attempt() <= FAILED_LOGIN_TIMEOUT
}

// Check if attempt counter can be reset.
// Counter can be reset after enough time has passed since the initial attempt.
// If user was blocked we also check if enough time (timeout) has passed since last attempt.
fn should_reset_counter(&self) -> bool {
self.time_since_first_attempt() > Duration::seconds(FAILED_LOGIN_WINDOW)
self.time_since_first_attempt() > FAILED_LOGIN_WINDOW
&& self.attempt_count < FAILED_LOGIN_COUNT
|| self.time_since_last_attempt() > Duration::seconds(FAILED_LOGIN_TIMEOUT)
|| self.time_since_last_attempt() > FAILED_LOGIN_TIMEOUT
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use serde::{Deserialize, Serialize};

use crate::{
appstate::AppState,
db::{Group, Id, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User},
db::{
models::{group::Group, user::User},
Id, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState,
},
error::WebError,
handlers::SESSION_COOKIE_NAME,
server_config,
Expand Down
7 changes: 5 additions & 2 deletions src/bin/defguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use std::{
use defguard::{
auth::failed_login::FailedLoginMap,
config::{Command, DefGuardConfig},
db::{init_db, models::wireguard::ChangeEvent, AppEvent, Settings, User},
db::{
init_db,
models::{settings::Settings, user::User, webhook::AppEvent, wireguard::ChangeEvent},
},
enterprise::license::{run_periodic_license_check, set_cached_license, License},
grpc::{
run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, GatewayMap, WorkerState,
Expand Down Expand Up @@ -118,7 +121,7 @@ async fn main() -> Result<(), anyhow::Error> {
tokio::select! {
res = run_grpc_gateway_stream(pool.clone(), events_tx.clone()) => error!("Gateway gRPC stream returned early: {res:#?}"),
res = run_grpc_bidi_stream(pool.clone(), events_tx.clone(), mail_tx.clone(), user_agent_parser.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"),
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_map), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"),
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"),
res = run_web_server(worker_state, gateway_map, webhook_tx, webhook_rx, events_tx.clone(), mail_tx, pool.clone(), user_agent_parser, failed_logins) => error!("Web server returned early: {res:#?}"),
res = run_mail_handler(mail_rx, pool.clone()) => error!("Mail handler returned early: {res:#?}"),
res = run_periodic_peer_disconnect(pool.clone(), events_tx) => error!("Periodic peer disconnect task returned early: {res:#?}"),
Expand Down
7 changes: 1 addition & 6 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub type Id = i64;

/// Initializes and migrates postgres database. Returns DB pool object.
pub async fn init_db(host: &str, port: u16, name: &str, user: &str, password: &str) -> PgPool {
info!("Initializing DB pool");
info!("Initializing pool of database connections");
let opts = PgConnectOptions::new()
.host(host)
.port(port)
Expand All @@ -28,14 +28,9 @@ pub async fn init_db(host: &str, port: u16, name: &str, user: &str, password: &s

pub use models::{
device::{AddDevice, Device},
group::Group,
oauth2authorizedapp::OAuth2AuthorizedApp,
oauth2token::OAuth2Token,
session::{Session, SessionState},
settings::Settings,
user::{MFAMethod, User},
wallet::Wallet,
webauthn::WebAuthn,
webhook::{AppEvent, HWKeyUserData, WebHook},
MFAInfo, UserDetails, UserInfo,
};
9 changes: 5 additions & 4 deletions src/db/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use utoipa::ToSchema;

use self::{
device::UserDevice,
group::Group,
user::{MFAMethod, User},
};
use super::{Group, Id};
use super::Id;

#[cfg(feature = "openid")]
#[derive(Deserialize, Serialize)]
Expand Down Expand Up @@ -83,7 +84,7 @@ pub struct UserInfo {
}

impl UserInfo {
pub async fn from_user(pool: &PgPool, user: &User<Id>) -> Result<Self, SqlxError> {
pub(crate) async fn from_user(pool: &PgPool, user: &User<Id>) -> Result<Self, SqlxError> {
let groups = user.member_of_names(pool).await?;
let authorized_apps = user.oauth2authorizedapps(pool).await?;

Expand Down Expand Up @@ -236,7 +237,7 @@ impl MFAInfo {
}

#[must_use]
pub fn mfa_available(&self) -> bool {
pub(crate) fn mfa_available(&self) -> bool {
self.webauthn_available
|| self.totp_available
|| self.web3_available
Expand All @@ -249,7 +250,7 @@ impl MFAInfo {
}

#[must_use]
pub fn list_available_methods(&self) -> Option<Vec<MFAMethod>> {
pub(crate) fn list_available_methods(&self) -> Option<Vec<MFAMethod>> {
if !self.mfa_available() {
return None;
}
Expand Down
13 changes: 7 additions & 6 deletions src/db/models/webhook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ pub enum AppEvent {
UserDeleted(String),
HWKeyProvision(HWKeyUserData),
}
/// User data send on HWKeyProvision AppEvent

/// User data sent on HWKeyProvision AppEvent
#[derive(Debug, Serialize)]
pub struct HWKeyUserData {
pub username: String,
Expand All @@ -27,16 +28,16 @@ impl AppEvent {
#[must_use]
pub fn name(&self) -> &str {
match self {
Self::UserCreated(_) => "user created",
Self::UserModified(_) => "user modified",
Self::UserDeleted(_) => "user deleted",
Self::HWKeyProvision(_) => "hwkey provisioned",
Self::UserCreated(_) => "user_created",
Self::UserModified(_) => "user_modified",
Self::UserDeleted(_) => "user_deleted",
Self::HWKeyProvision(_) => "user_keys",
}
}

/// Database column name.
#[must_use]
pub fn column_name(&self) -> &str {
pub(crate) fn column_name(&self) -> &str {
match self {
Self::UserCreated(_) => "on_user_created",
Self::UserModified(_) => "on_user_modified",
Expand Down
14 changes: 9 additions & 5 deletions src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ impl DateTimeAggregation {
}

#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ChangeEvent {
NetworkCreated(Id, WireguardNetwork<Id>),
NetworkModified(Id, WireguardNetwork<Id>, Vec<Peer>),
NetworkCreated(WireguardNetwork<Id>),
NetworkModified(WireguardNetwork<Id>, Vec<Peer>),
NetworkDeleted(Id, String),
DeviceCreated(DeviceInfo),
DeviceModified(DeviceInfo),
DeviceDeleted(DeviceInfo),
GatewayCreated(Id),
GatewayModified(Id),
GatewayDeleted(Id),
}

/// Stores configuration required to setup a WireGuard network
Expand All @@ -94,9 +98,9 @@ pub struct WireguardNetwork<I = NoId> {
pub peer_disconnect_threshold: i32,
}

pub struct WireguardKey {
pub private: String,
pub public: String,
pub(crate) struct WireguardKey {
pub(crate) private: String,
pub(crate) public: String,
}

impl fmt::Display for WireguardNetwork<NoId> {
Expand Down
5 changes: 1 addition & 4 deletions src/enterprise/handlers/enterprise_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ pub(crate) async fn get_enterprise_settings(
"User {} retrieved enterprise settings",
session.user.username
);
Ok(ApiResponse {
json: json!(settings),
status: StatusCode::OK,
})
Ok(ApiResponse::new(json!(settings), StatusCode::OK))
}

pub(crate) async fn patch_enterprise_settings(
Expand Down
9 changes: 5 additions & 4 deletions src/enterprise/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use axum::{
extract::{FromRef, FromRequestParts},
http::{request::Parts, StatusCode},
};
use serde_json::json;

use super::{db::models::enterprise_settings::EnterpriseSettings, license::get_cached_license};
use crate::{appstate::AppState, error::WebError};
Expand Down Expand Up @@ -50,12 +51,12 @@ pub(crate) async fn check_enterprise_status() -> Result<ApiResponse, WebError> {
}
)
});
Ok(ApiResponse {
json: serde_json::json!({ "enabled": valid,
Ok(ApiResponse::new(
json!({ "enabled": valid,
"license_info": license_info
}),
status: StatusCode::OK,
})
StatusCode::OK,
))
}

#[async_trait]
Expand Down
Loading

0 comments on commit 3450b3a

Please sign in to comment.