Skip to content

Commit

Permalink
Gateways is a list of string
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Oct 21, 2024
1 parent 6f21524 commit a374e31
Show file tree
Hide file tree
Showing 14 changed files with 60 additions and 59 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion migrations/20241015074303_network_gateways.up.sql
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ALTER TABLE wireguard_network ADD COLUMN gateways inet[] NOT NULL DEFAULT array[]::inet[];
ALTER TABLE wireguard_network ADD COLUMN gateways text[] NOT NULL DEFAULT array[]::text[];
44 changes: 31 additions & 13 deletions src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::{
pub const DEFAULT_KEEPALIVE_INTERVAL: i32 = 25;
pub const DEFAULT_DISCONNECT_THRESHOLD: i32 = 180;

// Used in process of importing network from wireguard config
#[derive(Clone, Debug, Deserialize, Serialize)]
// Used in process of importing network from WireGuard config
#[derive(Debug, Deserialize, Serialize)]
pub struct MappedDevice {
pub user_id: Id,
pub name: String,
Expand Down Expand Up @@ -88,8 +88,9 @@ pub struct WireguardNetwork<I = NoId> {
pub mfa_enabled: bool,
pub keepalive_interval: i32,
pub peer_disconnect_threshold: i32,
// URLs pointing to all gateways serving gRPC
#[model(ref)]
pub gateways: Vec<IpNetwork>,
pub gateways: Vec<String>,
}

pub struct WireguardKey {
Expand Down Expand Up @@ -130,6 +131,7 @@ pub enum WireguardNetworkError {
}

impl WireguardNetwork {
#[must_use]
pub fn new(
name: String,
address: IpNetwork,
Expand All @@ -140,10 +142,10 @@ impl WireguardNetwork {
mfa_enabled: bool,
keepalive_interval: i32,
peer_disconnect_threshold: i32,
) -> Result<Self, WireguardNetworkError> {
) -> Self {
let prvkey = StaticSecret::random_from_rng(OsRng);
let pubkey = PublicKey::from(&prvkey);
Ok(Self {
Self {
id: NoId,
name,
address,
Expand All @@ -158,7 +160,7 @@ impl WireguardNetwork {
keepalive_interval,
peer_disconnect_threshold,
gateways: Vec::new(),
})
}
}

/// Try to set `address` from `&str`.
Expand All @@ -173,12 +175,12 @@ impl WireguardNetwork<Id> {
pub async fn find_by_name<'e, E>(
executor: E,
name: &str,
) -> Result<Option<Vec<Self>>, WireguardNetworkError>
) -> Result<Option<Vec<Self>>, SqlxError>
where
E: PgExecutor<'e>,
{
let networks = query_as!(
WireguardNetwork,
Self,
"SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \
FROM wireguard_network WHERE name = $1",
Expand All @@ -194,6 +196,21 @@ impl WireguardNetwork<Id> {
})
}

/// Fetch all networks with MFA protection turned on.
pub(crate) async fn all_mfa_enabled<'e, E>(executor: E) -> Result<Vec<Self>, SqlxError>
where
E: PgExecutor<'e>,
{
query_as!(
Self,
"SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \
FROM wireguard_network WHERE mfa_enabled = true",
)
.fetch_all(executor)
.await
}

/// Run `sync_allowed_devices()` on all WireGuard networks.
pub async fn sync_all_networks(app: &AppState) -> Result<(), WireguardNetworkError> {
info!("Syncing allowed devices for all WireGuard locations");
Expand Down Expand Up @@ -394,7 +411,7 @@ impl WireguardNetwork<Id> {
Ok(wireguard_network_device)
} else {
info!("Device {device} not allowed in network {self}");
Err(WireguardNetworkError::DeviceNotAllowed(format!("{device}")))
Err(WireguardNetworkError::DeviceNotAllowed(device.to_string()))
}
}

Expand Down Expand Up @@ -561,7 +578,7 @@ impl WireguardNetwork<Id> {
pub async fn handle_mapped_devices(
&self,
transaction: &mut PgConnection,
mapped_devices: Vec<MappedDevice>,
mapped_devices: &[MappedDevice],
) -> Result<Vec<GatewayEvent>, WireguardNetworkError> {
info!("Mapping user devices for network {}", self);
// get allowed groups for network
Expand All @@ -570,7 +587,7 @@ impl WireguardNetwork<Id> {
let mut events = Vec::new();
// use a helper hashmap to avoid repeated queries
let mut user_groups = HashMap::new();
for mapped_device in &mapped_devices {
for mapped_device in mapped_devices {
debug!("Mapping device {}", mapped_device.name);
// validate device pubkey
Device::validate_pubkey(&mapped_device.wireguard_pubkey).map_err(|_| {
Expand Down Expand Up @@ -660,8 +677,9 @@ impl WireguardNetwork<Id> {
) -> Result<Option<WireguardPeerStats<Id>>, SqlxError> {
let stats = query_as!(
WireguardPeerStats,
"SELECT id, device_id \"device_id!\", collected_at \"collected_at!\", network \"network!\", \
endpoint, upload \"upload!\", download \"download!\", latest_handshake \"latest_handshake!\", allowed_ips \
"SELECT id, device_id \"device_id!\", collected_at \"collected_at!\", \
network \"network!\", endpoint, upload \"upload!\", download \"download!\", \
latest_handshake \"latest_handshake!\", allowed_ips \
FROM wireguard_peer_stats \
WHERE device_id = $1 AND network = $2 \
ORDER BY collected_at DESC \
Expand Down
30 changes: 12 additions & 18 deletions src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,20 +374,18 @@ pub async fn run_grpc_bidi_stream(
} else {
endpoint
};
let uri = endpoint.uri();

loop {
debug!("Connecting to proxy at {}", endpoint.uri());
debug!("Connecting to proxy at {uri}",);
let mut client = ProxyClient::new(endpoint.connect_lazy());
let (tx, rx) = mpsc::unbounded_channel();
let Ok(response) = client.bidi(UnboundedReceiverStream::new(rx)).await else {
error!(
"Failed to connect to proxy @ {}, retrying in 10s",
endpoint.uri()
);
error!("Failed to connect to proxy @ {uri}, retrying in 10s",);
sleep(TEN_SECS).await;
continue;
};
info!("Connected to proxy at {}", endpoint.uri());
info!("Connected to proxy at {uri}");
let mut resp_stream = response.into_inner();
'message: loop {
match resp_stream.message().await {
Expand Down Expand Up @@ -520,18 +518,14 @@ pub async fn run_grpc_bidi_stream(
Some(core_response::Payload::InstanceInfo(response_payload))
}
Err(err) => {
match err.code() {
// Ignore the case when we are not enterprise but the client is trying to fetch the instance config,
// to avoid spamming the logs with misleading errors.
Code::FailedPrecondition => {
debug!("A client tried to fetch the instance config, but we are not enterprise.");
Some(core_response::Payload::CoreError(err.into()))
}
_ => {
error!("Instance info error {err}");
Some(core_response::Payload::CoreError(err.into()))
}
// Ignore the case when we are not enterprise but the client is trying to fetch the instance config,
// to avoid spamming the logs with misleading errors.
if err.code() == Code::FailedPrecondition {
debug!("A client tried to fetch the instance config, but we are not enterprise.");
} else {
error!("Instance info error {err}");
}
Some(core_response::Payload::CoreError(err.into()))
}
}
}
Expand All @@ -545,7 +539,7 @@ pub async fn run_grpc_bidi_stream(
tx.send(req).unwrap();
}
Err(err) => {
error!("Disconnected from proxy at {}", endpoint.uri());
error!("Disconnected from proxy at {uri}");
error!("stream error: {err}");
debug!("waiting 10s to re-establish the connection");
sleep(TEN_SECS).await;
Expand Down
9 changes: 4 additions & 5 deletions src/handlers/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl WireguardNetworkData {
}

// Used in process of importing network from WireGuard config
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Debug, Deserialize, Serialize)]
pub struct MappedDevices {
pub devices: Vec<MappedDevice>,
}
Expand Down Expand Up @@ -114,8 +114,7 @@ pub async fn create_network(
data.mfa_enabled,
data.keepalive_interval,
data.peer_disconnect_threshold,
)
.map_err(|_| WebError::Serialization("Invalid network address".into()))?;
);

let mut transaction = appstate.pool.begin().await?;
let network = network.save(&mut *transaction).await?;
Expand Down Expand Up @@ -388,7 +387,7 @@ pub async fn add_user_devices(
Path(network_id): Path<i64>,
Json(request_data): Json<MappedDevices>,
) -> ApiResult {
let mapped_devices = request_data.devices.clone();
let mapped_devices = request_data.devices;
let user = session.user;
let device_count = mapped_devices.len();

Expand All @@ -410,7 +409,7 @@ pub async fn add_user_devices(
// wrap loop in transaction to abort if a device is invalid
let mut transaction = appstate.pool.begin().await?;
let events = network
.handle_mapped_devices(&mut transaction, mapped_devices)
.handle_mapped_devices(&mut transaction, mapped_devices.as_slice())
.await?;
appstate.send_multiple_wireguard_events(events);
transaction.commit().await?;
Expand Down
5 changes: 2 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,7 @@ pub async fn init_dev_env(config: &DefGuardConfig) {
false,
DEFAULT_KEEPALIVE_INTERVAL,
DEFAULT_DISCONNECT_THRESHOLD,
)
.expect("Could not create network");
);
network.pubkey = "zGMeVGm9HV9I4wSKF9AXmYnnAIhDySyqLMuKpcfIaQo=".to_string();
network.prvkey = "MAk3d5KuB167G88HM7nGYR6ksnPMAOguAg2s5EcPp1M=".to_string();
network
Expand Down Expand Up @@ -674,7 +673,7 @@ pub async fn init_vpn_location(
false,
DEFAULT_KEEPALIVE_INTERVAL,
DEFAULT_DISCONNECT_THRESHOLD,
)?
)
.save(pool)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion src/wg_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub fn parse_wireguard_config(
false,
DEFAULT_KEEPALIVE_INTERVAL,
DEFAULT_DISCONNECT_THRESHOLD,
)?;
);
network.pubkey = pubkey;
network.prvkey = prvkey.to_string();

Expand Down
12 changes: 2 additions & 10 deletions src/wireguard_peer_disconnect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::db::{
error::ModelError,
wireguard::WireguardNetworkError,
},
Device, GatewayEvent, Id, WireguardNetwork,
Device, GatewayEvent, WireguardNetwork,
};

// How long to sleep between loop iterations
Expand Down Expand Up @@ -45,15 +45,7 @@ pub async fn run_periodic_peer_disconnect(
loop {
debug!("Starting periodic inactive device disconnect");

// get all MFA-protected locations
let locations = query_as!(
WireguardNetwork::<Id>,
"SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \
FROM wireguard_network WHERE mfa_enabled = true",
)
.fetch_all(&pool)
.await?;
let locations = WireguardNetwork::all_mfa_enabled(&pool).await?;

// loop over all locations
for location in locations {
Expand Down
3 changes: 1 addition & 2 deletions tests/wireguard_network_import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ async fn test_config_import() {
false,
DEFAULT_KEEPALIVE_INTERVAL,
DEFAULT_DISCONNECT_THRESHOLD,
)
.unwrap();
);
initial_network.save(&pool).await.unwrap();

// add existing devices
Expand Down

0 comments on commit a374e31

Please sign in to comment.