Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add burst guard to macOS route monitor #5225

Merged
merged 2 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions talpid-core/src/offline/macos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ use std::sync::{
};
use talpid_routing::{DefaultRouteEvent, RouteManagerHandle};

/// How long to wait before announcing changes to the offline state
//const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(2);

#[derive(err_derive::Error, Debug)]
pub enum Error {
#[error(display = "Failed to initialize route monitor")]
Expand Down Expand Up @@ -120,11 +117,6 @@ pub async fn spawn_monitor(
None => return,
};

// Debounce event updates
// FIXME: Debounce is disabled because the DNS config can get messed up
// when switching between networks otherwise.
//tokio::time::sleep(DEBOUNCE_INTERVAL).await;

if prev_notified.swap(new_connectivity, Ordering::AcqRel) == new_connectivity {
// We don't care about network changes here
return;
Expand Down
95 changes: 95 additions & 0 deletions talpid-routing/src/debounce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#![allow(dead_code)]

use std::{
sync::mpsc::{channel, RecvTimeoutError, Sender},
time::{Duration, Instant},
};

/// BurstGuard is a wrapper for a function that protects that function from being called too many
/// times in a short amount of time. To call the function use `burst_guard.trigger()`, at that point
/// `BurstGuard` will wait for `buffer_period` and if no more calls to `trigger` are made then it
/// will call the wrapped function. If another call to `trigger` is made during this wait then it
/// will wait another `buffer_period`, this happens over and over until either
/// `longest_buffer_period` time has elapsed or until no call to `trigger` has been made in
/// `buffer_period`. At which point the wrapped function will be called.
pub struct BurstGuard {
sender: Sender<BurstGuardEvent>,
}

enum BurstGuardEvent {
Trigger,
Shutdown(Sender<()>),
}

impl BurstGuard {
pub fn new<F: Fn() + Send + 'static>(callback: F) -> Self {
/// This is the period of time the `BurstGuard` will wait for a new trigger to be sent
/// before it calls the callback.
const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200);
/// This is the longest period that the `BurstGuard` will wait from the first trigger till
/// it calls the callback.
const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2);

let (sender, listener) = channel();
std::thread::spawn(move || {
// The `stop` implementation assumes that this thread will not call `callback` again
// if the listener has been dropped.
while let Ok(message) = listener.recv() {
match message {
BurstGuardEvent::Trigger => {
let start = Instant::now();
loop {
match listener.recv_timeout(BURST_BUFFER_PERIOD) {
Ok(BurstGuardEvent::Trigger) => {
if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD {
callback();
break;
}
}
Ok(BurstGuardEvent::Shutdown(tx)) => {
let _ = tx.send(());
return;
}
Err(RecvTimeoutError::Timeout) => {
callback();
break;
}
Err(RecvTimeoutError::Disconnected) => {
break;
}
}
}
}
BurstGuardEvent::Shutdown(tx) => {
let _ = tx.send(());
return;
}
}
}
});
Self { sender }
}

/// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further
/// calls to `callback`.
pub fn stop(self) {
let (sender, listener) = channel();
// If we could not send then it means the thread has already shut down and we can return
if self.sender.send(BurstGuardEvent::Shutdown(sender)).is_ok() {
// We do not care what the result is, if it is OK it means the thread shut down, if
// it is Err it also means it shut down.
let _ = listener.recv();
}
}

/// Stop without waiting for in-flight events to complete.
pub fn stop_nonblocking(self) {
let (sender, _listener) = channel();
let _ = self.sender.send(BurstGuardEvent::Shutdown(sender));
}

/// Asynchronously trigger burst
pub fn trigger(&self) {
self.sender.send(BurstGuardEvent::Trigger).unwrap();
}
}
3 changes: 3 additions & 0 deletions talpid-routing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
use ipnetwork::IpNetwork;
use std::{fmt, net::IpAddr};

#[cfg(any(target_os = "windows", target_os = "macos"))]
mod debounce;

#[cfg(target_os = "windows")]
#[path = "windows/mod.rs"]
mod imp;
Expand Down
59 changes: 40 additions & 19 deletions talpid-routing/src/unix/macos/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{NetNode, Node, RequiredRoute, Route};
use crate::{debounce::BurstGuard, NetNode, Node, RequiredRoute, Route};

use futures::{
channel::mpsc,
Expand All @@ -7,11 +7,11 @@ use futures::{
};
use ipnetwork::IpNetwork;
use nix::sys::socket::{AddressFamily, SockaddrLike, SockaddrStorage};
use std::pin::Pin;
use std::{
collections::{BTreeMap, HashSet},
time::Duration,
};
use std::{pin::Pin, sync::Weak};
use talpid_types::ErrorExt;
use watch::RoutingTable;

Expand Down Expand Up @@ -85,15 +85,24 @@ pub struct RouteManagerImpl {
applied_routes: BTreeMap<RouteDestination, RouteMessage>,
v4_default_route: Option<data::RouteMessage>,
v6_default_route: Option<data::RouteMessage>,
update_trigger: BurstGuard,
default_route_listeners: Vec<mpsc::UnboundedSender<DefaultRouteEvent>>,
check_default_routes_restored: Pin<Box<dyn FusedStream<Item = ()> + Send>>,
}

impl RouteManagerImpl {
/// Create new route manager
#[allow(clippy::unused_async)]
pub async fn new() -> Result<Self> {
pub(crate) async fn new(
manage_tx: Weak<mpsc::UnboundedSender<RouteManagerCommand>>,
) -> Result<Self> {
let routing_table = RoutingTable::new().map_err(Error::RoutingTable)?;
let update_trigger = BurstGuard::new(move || {
let Some(manage_tx) = manage_tx.upgrade() else {
return;
};
let _ = manage_tx.unbounded_send(RouteManagerCommand::RefreshRoutes);
});
Ok(Self {
routing_table,
non_tunnel_routes: HashSet::new(),
Expand All @@ -102,6 +111,7 @@ impl RouteManagerImpl {
applied_routes: BTreeMap::new(),
v4_default_route: None,
v6_default_route: None,
update_trigger,
default_route_listeners: vec![],
check_default_routes_restored: Box::pin(futures::stream::pending()),
})
Expand Down Expand Up @@ -129,10 +139,12 @@ impl RouteManagerImpl {
);
});

let mut completion_tx = None;

loop {
futures::select_biased! {
route_message = self.routing_table.next_message().fuse() => {
self.handle_route_message(route_message).await;
self.handle_route_message(route_message);
}

_ = self.check_default_routes_restored.next() => {
Expand All @@ -148,11 +160,8 @@ impl RouteManagerImpl {
command = manage_rx.next() => {
match command {
Some(RouteManagerCommand::Shutdown(tx)) => {
if let Err(err) = self.cleanup_routes().await {
log::error!("Failed to clean up routes: {err}");
}
let _ = tx.send(());
return;
completion_tx = Some(tx);
break;
},

Some(RouteManagerCommand::NewDefaultRouteListener(tx)) => {
Expand Down Expand Up @@ -214,6 +223,11 @@ impl RouteManagerImpl {
log::error!("Failed to clean up rotues: {err}");
}
},
Some(RouteManagerCommand::RefreshRoutes) => {
if let Err(error) = self.refresh_routes().await {
log::error!("Failed to refresh routes: {error}")
}
},
None => {
break;
}
Expand All @@ -225,6 +239,12 @@ impl RouteManagerImpl {
if let Err(err) = self.cleanup_routes().await {
log::error!("Failed to clean up routing table when shutting down: {err}");
}

self.update_trigger.stop_nonblocking();

if let Some(tx) = completion_tx {
let _ = tx.send(());
}
}

async fn add_required_routes(&mut self, required_routes: HashSet<RequiredRoute>) -> Result<()> {
Expand Down Expand Up @@ -287,7 +307,7 @@ impl RouteManagerImpl {
Ok(())
}

async fn handle_route_message(
fn handle_route_message(
&mut self,
message: std::result::Result<RouteSocketMessage, watch::Error>,
) {
Expand All @@ -303,18 +323,19 @@ impl RouteManagerImpl {
log::error!("Failed to process deleted route: {err}");
}
}

if let Err(error) = self.handle_route_socket_message().await {
log::error!("Failed to process route change: {error}");
if route.errno() == 0 && route.is_default().unwrap_or(true) {
self.update_trigger.trigger();
}
}
Ok(RouteSocketMessage::AddRoute(_))
| Ok(RouteSocketMessage::ChangeRoute(_))
| Ok(RouteSocketMessage::AddAddress(_) | RouteSocketMessage::DeleteAddress(_)) => {
if let Err(error) = self.handle_route_socket_message().await {
log::error!("Failed to process route/address change: {error}");
Ok(RouteSocketMessage::AddRoute(route))
| Ok(RouteSocketMessage::ChangeRoute(route)) => {
if route.errno() == 0 && route.is_default().unwrap_or(true) {
self.update_trigger.trigger();
}
}
Ok(RouteSocketMessage::AddAddress(_) | RouteSocketMessage::DeleteAddress(_)) => {
self.update_trigger.trigger();
}
// ignore all other message types
Ok(_) => {}
Err(err) => {
Expand All @@ -329,7 +350,7 @@ impl RouteManagerImpl {
/// * At the same time, update the route used by non-tunnel interfaces to reach the relay/VPN
/// server. The gateway of the relay route is set to the first interface in the network
/// service order that has a working ifscoped default route.
async fn handle_route_socket_message(&mut self) -> Result<()> {
async fn refresh_routes(&mut self) -> Result<()> {
self.update_best_default_route(interface::Family::V4)
.await?;
self.update_best_default_route(interface::Family::V6)
Expand Down
11 changes: 8 additions & 3 deletions talpid-routing/src/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use futures::channel::{
mpsc::{self, UnboundedSender},
oneshot,
};
use std::{collections::HashSet, io};
use std::{collections::HashSet, io, sync::Arc};

#[cfg(any(target_os = "linux", target_os = "macos"))]
use futures::stream::Stream;
Expand Down Expand Up @@ -55,7 +55,7 @@ pub enum Error {
/// Handle to a route manager.
#[derive(Clone)]
pub struct RouteManagerHandle {
tx: UnboundedSender<RouteManagerCommand>,
tx: Arc<UnboundedSender<RouteManagerCommand>>,
}

impl RouteManagerHandle {
Expand Down Expand Up @@ -181,6 +181,8 @@ pub(crate) enum RouteManagerCommand {
ClearRoutes,
Shutdown(oneshot::Sender<()>),
#[cfg(target_os = "macos")]
RefreshRoutes,
#[cfg(target_os = "macos")]
NewDefaultRouteListener(oneshot::Sender<mpsc::UnboundedReceiver<DefaultRouteEvent>>),
#[cfg(target_os = "macos")]
GetDefaultRoutes(oneshot::Sender<(Option<Route>, Option<Route>)>),
Expand Down Expand Up @@ -227,7 +229,7 @@ pub enum CallbackMessage {
/// If a destination has to be routed through the default node,
/// the route will be adjusted dynamically when the default route changes.
pub struct RouteManager {
manage_tx: Option<UnboundedSender<RouteManagerCommand>>,
manage_tx: Option<Arc<UnboundedSender<RouteManagerCommand>>>,
runtime: tokio::runtime::Handle,
}

Expand All @@ -238,11 +240,14 @@ impl RouteManager {
#[cfg(target_os = "linux")] table_id: u32,
) -> Result<Self, Error> {
let (manage_tx, manage_rx) = mpsc::unbounded();
let manage_tx = Arc::new(manage_tx);
let manager = imp::RouteManagerImpl::new(
#[cfg(target_os = "linux")]
fwmark,
#[cfg(target_os = "linux")]
table_id,
#[cfg(target_os = "macos")]
Arc::downgrade(&manage_tx),
)
.await?;
tokio::spawn(manager.run(manage_rx));
Expand Down
Loading