Skip to content

Commit

Permalink
Merge pull request #659 from DefGuard/dev
Browse files Browse the repository at this point in the history
merge dev -> main
  • Loading branch information
t-aleksander authored Jul 3, 2024
2 parents 7e6ad6f + b6f9842 commit 39015c7
Show file tree
Hide file tree
Showing 17 changed files with 563 additions and 559 deletions.
891 changes: 455 additions & 436 deletions Cargo.lock

Large diffs are not rendered by default.

27 changes: 15 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"
license = "Apache-2.0"
homepage = "https://defguard.net/"
repository = "https://github.com/DefGuard/defguard"
rust-version = "1.76"

[workspace]

Expand All @@ -18,7 +19,7 @@ axum-extra = { version = "0.9", features = [
"cookie-private",
"typed-header",
] }
base64 = "0.21"
base64 = "0.22"
chrono = { version = "0.4", default-features = false, features = [
"clock",
"serde",
Expand All @@ -30,25 +31,26 @@ ethers-core = "2.0"
humantime = "2.1"
# match ipnetwork version from sqlx
ipnetwork = { version = "0.20", features = ["serde"] }
jsonwebtoken = "9.2"
jsonwebtoken = "9.3"
ldap3 = { version = "0.11", default-features = false, features = ["tls"] }
lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] }
md4 = "0.10"
mime_guess = "2.0"
model_derive = { path = "model-derive" }
openidconnect = { version = "3.4", default-features = false, optional = true }
openidconnect = { version = "3.5", default-features = false, optional = true }
otpauth = "0.4"
prost = "0.12"
pulldown-cmark = "0.9"
pulldown-cmark = "0.11"
rand = "0.8"
rand_core = { version = "0.6", default-features = false, features = [
"getrandom",
] }
# TODO: update reqwest when openidconnect also depends on http >= 1.0.
reqwest = { version = "0.11", features = ["json"] }
rsa = { version = "0.9", features = ["pem"] }
rust-embed = { version = "8.4", features = ["include-exclude"] }
rust-ini = "0.20"
secp256k1 = { version = "0.28", features = [
secp256k1 = { version = "0.29", features = [
"recovery",
"rand-std",
"global-context",
Expand All @@ -69,7 +71,7 @@ sqlx = { version = "0.7", features = [
] }
ssh-key = "0.6"
struct-patch = "0.4"
tera = "1.19"
tera = "1.20"
thiserror = "1.0"
# match axum-extra -> cookies
time = { version = "0.3", default-features = false }
Expand All @@ -88,16 +90,16 @@ tower-http = { version = "0.5", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uaparser = "0.6"
uuid = { version = "1.4", features = ["v4"] }
webauthn-authenticator-rs = { version = "0.4" }
webauthn-rs = { version = "0.4", features = [
uuid = { version = "1.9", features = ["v4"] }
webauthn-authenticator-rs = { version = "0.5" }
webauthn-rs = { version = "0.5", features = [
"danger-allow-state-serialisation",
] }
webauthn-rs-proto = "0.4"
webauthn-rs-proto = "0.5"
x25519-dalek = { version = "2.0", features = ["static_secrets"] }

[dev-dependencies]
bytes = "1.5"
bytes = "1.6"
claims = "0.7"
matches = "0.1"
regex = "1.10"
Expand All @@ -108,7 +110,8 @@ reqwest = { version = "0.11", features = [
"multipart",
"rustls-tls",
], default-features = false }
serde_qs = "0.12"
serde_qs = "0.13"
webauthn-authenticator-rs = { version = "0.5", features = ["softpasskey"] }

[build-dependencies]
prost-build = "0.12"
Expand Down
4 changes: 4 additions & 0 deletions migrations/20240216195802_authentication_key.down.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
DROP TABLE authentication_key;

DROP TABLE yubikey;

DROP TYPE authentication_key_type;

ALTER TABLE "user"
ADD pgp_key text NULL,
ADD pgp_cert_id text NULL,
Expand Down
2 changes: 0 additions & 2 deletions rust-toolchain.toml

This file was deleted.

3 changes: 1 addition & 2 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,9 @@ impl Claims {

fn get_secret(claims_type: ClaimsType) -> String {
let env_var = match claims_type {
ClaimsType::Auth => AUTH_SECRET_ENV,
ClaimsType::Auth | ClaimsType::DesktopClient => AUTH_SECRET_ENV,
ClaimsType::Gateway => GATEWAY_SECRET_ENV,
ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV,
ClaimsType::DesktopClient => AUTH_SECRET_ENV,
};
env::var(env_var).unwrap_or_default()
}
Expand Down
18 changes: 9 additions & 9 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,16 @@ impl DefGuardConfig {

fn validate_secret_key(&self) {
let secret_key = self.secret_key.expose_secret();
if secret_key.trim().len() != secret_key.len() {
panic!("SECRET_KEY cannot have leading and trailing space",);
}
assert!(
secret_key.trim().len() == secret_key.len(),
"SECRET_KEY cannot have leading and trailing space",
);

if secret_key.len() < 64 {
panic!(
"SECRET_KEY must be at least 64 characters long, provided value has {} characters",
secret_key.len()
);
}
assert!(
secret_key.len() >= 64,
"SECRET_KEY must be at least 64 characters long, provided value has {} characters",
secret_key.len()
);
}

/// Try PKCS#1 and PKCS#8 PEM formats.
Expand Down
6 changes: 3 additions & 3 deletions src/db/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ impl UserInfo {
transaction: &mut PgConnection,
user: &mut User,
) -> Result<bool, SqlxError> {
if self.is_active != user.is_active {
if self.is_active == user.is_active {
Ok(false)
} else {
if !self.is_active {
user.logout_all_sessions(&mut *transaction).await?;
}
user.is_active = self.is_active;
user.save(&mut *transaction).await?;
Ok(true)
} else {
Ok(false)
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,7 @@ impl WireguardNetwork {
Ok(wireguard_network_device)
} else {
error!("Device {device} not allowed in network {self}");
Err(WireguardNetworkError::DeviceNotAllowed(format!(
"{}",
device
)))
Err(WireguardNetworkError::DeviceNotAllowed(format!("{device}")))
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/grpc/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl ClientMfaServer {
}

/// Validate JWT and extract client pubkey
fn parse_token(&self, token: &str) -> Result<String, Status> {
fn parse_token(token: &str) -> Result<String, Status> {
let claims = Claims::from_jwt(ClaimsType::DesktopClient, token).map_err(|err| {
error!("Failed to parse JWT token: {err:?}");
Status::invalid_argument("invalid token")
Expand Down Expand Up @@ -185,7 +185,7 @@ impl ClientMfaServer {
) -> Result<ClientMfaFinishResponse, Status> {
debug!("Finishing desktop client login: {request:?}");
// get pubkey from token
let pubkey = self.parse_token(&request.token)?;
let pubkey = Self::parse_token(&request.token)?;

// fetch login session
let Some(session) = self.sessions.get(&pubkey) else {
Expand Down
2 changes: 1 addition & 1 deletion src/grpc/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ impl gateway_service_server::GatewayService for GatewayServer {
.ok_or_else(|| {
Status::new(
Code::Internal,
format!("Network with id {} not found", network_id),
format!("Network with id {network_id} not found"),
)
})?;

Expand Down
11 changes: 3 additions & 8 deletions src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,8 @@ impl GatewayMap {
if let Some(state) = network_gateway_map.get_mut(&hostname) {
state.connected = false;
state.disconnected_at = Some(Utc::now().naive_utc());
state.send_disconnect_notification(pool)?;
debug!(
"Gateway {hostname} found in gateway map, current state: {:#?}",
state
);
state.send_disconnect_notification(pool);
debug!("Gateway {hostname} found in gateway map, current state: {state:#?}");
info!("Gateway {hostname} disconnected in network {network_id}");
return Ok(());
};
Expand Down Expand Up @@ -290,7 +287,7 @@ impl GatewayState {

/// Send gateway disconnected notification
/// Sends notification only if last notification time is bigger than specified in config
fn send_disconnect_notification(&mut self, pool: &DbPool) -> Result<(), GatewayMapError> {
fn send_disconnect_notification(&mut self, pool: &DbPool) {
debug!("Sending gateway disconnect email notification");
// Clone here because self doesn't live long enough
let name = self.name.clone();
Expand Down Expand Up @@ -327,8 +324,6 @@ impl GatewayState {
self.last_email_notification
);
};

Ok(())
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/handlers/forward_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ pub async fn forward_auth(
}
// If no session cookie provided redirect to login
info!("Valid session not found, redirecting to login page");
login_redirect(headers).await
login_redirect(headers)
}

async fn login_redirect(headers: ForwardAuthHeaders) -> Result<ForwardAuthResponse, WebError> {
fn login_redirect(headers: ForwardAuthHeaders) -> Result<ForwardAuthResponse, WebError> {
let server_url = &server_config().url; // prepare redirect URL for login page
let mut location = server_url.join("/auth/login").map_err(|err| {
error!("Failed to prepare redirect URL: {err}");
Expand Down
14 changes: 7 additions & 7 deletions src/handlers/openid_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ fn redirect_to<T: AsRef<str>>(

/// Helper function to redirect unauthorized user to login page
/// and store information about OpenID authorize url in cookie to redirect later
async fn login_redirect(
fn login_redirect(
data: &AuthenticationRequest,
private_cookies: PrivateCookieJar,
) -> Result<(StatusCode, HeaderMap, PrivateCookieJar), WebError> {
) -> (StatusCode, HeaderMap, PrivateCookieJar) {
let config = server_config();
let base_url = config.url.join("api/v1/oauth/authorize").unwrap();
let cookie = Cookie::build((
Expand All @@ -358,7 +358,7 @@ async fn login_redirect(
.same_site(SameSite::Lax)
.http_only(true)
.max_age(Duration::minutes(10));
Ok(redirect_to("/login", private_cookies.add(cookie)))
redirect_to("/login", private_cookies.add(cookie))
}

/// Authorization Endpoint
Expand Down Expand Up @@ -400,7 +400,7 @@ pub async fn authorization(
if session.expired() {
info!("Session {} for user id {} has expired, redirecting to login", session.id, session.user_id);
let _result = session.delete(&appstate.pool).await;
login_redirect(&data, private_cookies).await
Ok(login_redirect(&data, private_cookies))
} else {
let user = User::find_by_id(&appstate.pool, session.user_id)
.await?
Expand All @@ -415,7 +415,7 @@ pub async fn authorization(
"MFA not verified for user id {}, redirecting to login",
session.user_id
);
return login_redirect(&data, private_cookies).await;
return Ok(login_redirect(&data, private_cookies));
}

// If session is present check if app is in user authorized apps.
Expand Down Expand Up @@ -462,13 +462,13 @@ pub async fn authorization(
"Session {} not found, redirecting to login page",
session_cookie.value()
);
login_redirect(&data, private_cookies).await
Ok(login_redirect(&data, private_cookies))
}

// If no session cookie provided redirect to login
} else {
info!("Session cookie not provided, redirecting to login page");
login_redirect(&data, private_cookies).await
Ok(login_redirect(&data, private_cookies))
};
}
}
Expand Down
52 changes: 22 additions & 30 deletions src/handlers/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ pub struct MappedDevices {
pub devices: Vec<MappedDevice>,
}

#[derive(Serialize)]
struct ConnectionInfo {
connected: bool,
}

#[derive(Deserialize)]
pub struct ImportNetworkData {
pub name: String,
Expand Down Expand Up @@ -425,32 +420,29 @@ pub async fn add_user_devices(
});
}

match WireguardNetwork::find_by_id(&appstate.pool, network_id).await? {
Some(network) => {
// wrap loop in transaction to abort if a device is invalid
let mut transaction = appstate.pool.begin().await?;
let events = network
.handle_mapped_devices(&mut transaction, mapped_devices)
.await?;
appstate.send_multiple_wireguard_events(events);
transaction.commit().await?;

info!(
"User {} mapped {device_count} devices for {network_id} network",
user.username,
);
if let Some(network) = WireguardNetwork::find_by_id(&appstate.pool, network_id).await? {
// wrap loop in transaction to abort if a device is invalid
let mut transaction = appstate.pool.begin().await?;
let events = network
.handle_mapped_devices(&mut transaction, mapped_devices)
.await?;
appstate.send_multiple_wireguard_events(events);
transaction.commit().await?;

Ok(ApiResponse {
json: json!({}),
status: StatusCode::CREATED,
})
}
None => {
error!("Failed to map devices, network {network_id} not found");
Err(WebError::ObjectNotFound(format!(
"Network {network_id} not found"
)))
}
info!(
"User {} mapped {device_count} devices for {network_id} network",
user.username,
);

Ok(ApiResponse {
json: json!({}),
status: StatusCode::CREATED,
})
} else {
error!("Failed to map devices, network {network_id} not found");
Err(WebError::ObjectNotFound(format!(
"Network {network_id} not found"
)))
}
}

Expand Down
Loading

0 comments on commit 39015c7

Please sign in to comment.