Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect changes to interfaces in the dynamic store #5254

Merged
merged 8 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ Line wrap the file at 100 chars. Th
#### Windows
- Correctly detect whether OS is Windows Server (primarily for logging in daemon.log).

#### macOS
- Fix connectivity issues when switching between networks or disconnecting.


## [android/2023.6] - 2023-09-25
### Fixed
Expand Down
112 changes: 78 additions & 34 deletions talpid-core/src/offline/macos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,24 @@
//! user from connecting to a relay.
//!
//! See [RouteManagerHandle::default_route_listener].
use futures::{channel::mpsc::UnboundedSender, StreamExt};
use std::sync::{Arc, Mutex};
//!
//! This offline monitor synthesizes an offline state between network switches and before coming
//! online from an offline state. This is done to work around issues with DNS being blocked due
//! to macOS's connectivity check. In the offline state, a DNS server on localhost prevents the
//! connectivity check from being blocked.
use futures::{
channel::mpsc::UnboundedSender,
future::{Fuse, FutureExt},
select, StreamExt,
};
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use talpid_routing::{DefaultRouteEvent, RouteManagerHandle};

const SYNTHETIC_OFFLINE_DURATION: Duration = Duration::from_secs(1);

#[derive(err_derive::Error, Debug)]
pub enum Error {
#[error(display = "Failed to initialize route monitor")]
Expand All @@ -18,6 +32,7 @@ pub struct MonitorHandle {
_notify_tx: Arc<UnboundedSender<bool>>,
}

#[derive(Clone)]
struct ConnectivityState {
v4_connectivity: bool,
v6_connectivity: bool,
Expand Down Expand Up @@ -45,7 +60,7 @@ pub async fn spawn_monitor(
let notify_tx = Arc::new(notify_tx);

// note: begin observing before initializing the state
let mut route_listener = route_manager_handle.default_route_listener().await?;
let route_listener = route_manager_handle.default_route_listener().await?;

let (v4_connectivity, v6_connectivity) = match route_manager_handle.get_default_routes().await {
Ok((v4_route, v6_route)) => (v4_route.is_some(), v6_route.is_some()),
Expand All @@ -61,50 +76,79 @@ pub async fn spawn_monitor(
v4_connectivity,
v6_connectivity,
};
let mut real_state = state.clone();

let state = Arc::new(Mutex::new(state));

let weak_state = Arc::downgrade(&state);
let weak_notify_tx = Arc::downgrade(&notify_tx);

// Detect changes to the default route
tokio::spawn(async move {
while let Some(event) = route_listener.next().await {
let Some(state) = weak_state.upgrade() else {
break;
};
let mut state = state.lock().unwrap();
let mut timeout = Fuse::terminated();
let mut route_listener = route_listener.fuse();

loop {
select! {
_ = timeout => {
// Update shared state
let Some(state) = weak_state.upgrade() else {
break;
};
let mut state = state.lock().unwrap();
*state = real_state.clone();

if state.get_connectivity() {
log::info!("Connectivity changed: Connected");
let Some(tx) = weak_notify_tx.upgrade() else {
break;
};
let _ = tx.unbounded_send(false);
}
}

let previous_connectivity = state.get_connectivity();
route_event = route_listener.next() => {
let Some(event) = route_event else {
break;
};

// Update real state
match event {
DefaultRouteEvent::AddedOrChangedV4 => {
real_state.v4_connectivity = true;
}
DefaultRouteEvent::AddedOrChangedV6 => {
real_state.v6_connectivity = true;
}
DefaultRouteEvent::RemovedV4 => {
real_state.v4_connectivity = false;
}
DefaultRouteEvent::RemovedV6 => {
real_state.v6_connectivity = false;
}
}

match event {
DefaultRouteEvent::AddedOrChangedV4 => {
state.v4_connectivity = true;
}
DefaultRouteEvent::AddedOrChangedV6 => {
state.v6_connectivity = true;
}
DefaultRouteEvent::RemovedV4 => {
// Synthesize offline state
// Update shared state
let Some(state) = weak_state.upgrade() else {
break;
};
let mut state = state.lock().unwrap();
let previous_connectivity = state.get_connectivity();
state.v4_connectivity = false;
}
DefaultRouteEvent::RemovedV6 => {
state.v6_connectivity = false;
}
}

let new_connectivity = state.get_connectivity();
if previous_connectivity != new_connectivity {
log::info!(
"Connectivity changed: {}",
if new_connectivity {
"Connected"
} else {
"Offline"
if previous_connectivity {
let Some(tx) = weak_notify_tx.upgrade() else {
break;
};
let _ = tx.unbounded_send(true);
log::info!("Connectivity changed: Offline");
}
if real_state.get_connectivity() {
timeout = Box::pin(tokio::time::sleep(SYNTHETIC_OFFLINE_DURATION)).fuse();
}
);
let Some(tx) = weak_notify_tx.upgrade() else {
break;
};
let _ = tx.unbounded_send(!new_connectivity);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion talpid-core/src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use trust_dns_server::{
ServerFuture,
};

const ALLOWED_RECORD_TYPES: &[RecordType] = &[RecordType::A, RecordType::AAAA, RecordType::CNAME];
const ALLOWED_RECORD_TYPES: &[RecordType] = &[RecordType::A, RecordType::CNAME];
const CAPTIVE_PORTAL_DOMAINS: &[&str] = &["captive.apple.com", "netcts.cdn-apple.com"];

static ALLOWED_DOMAINS: Lazy<Vec<LowerName>> = Lazy::new(|| {
Expand Down
6 changes: 6 additions & 0 deletions talpid-core/src/tunnel_state_machine/connecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,12 @@ impl TunnelState for ConnectingState {
retry_attempt: u32,
) -> (TunnelStateWrapper, TunnelStateTransition) {
if shared_values.is_offline {
// FIXME: Temporary: Nudge route manager to update the default interface
#[cfg(target_os = "macos")]
if let Ok(handle) = shared_values.route_manager.handle() {
log::debug!("Poking route manager to update default routes");
let _ = handle.refresh_routes();
}
return ErrorState::enter(shared_values, ErrorStateCause::IsOffline);
}
match shared_values.runtime.block_on(
Expand Down
46 changes: 31 additions & 15 deletions talpid-routing/src/debounce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,39 @@ use std::{
/// `buffer_period`. At which point the wrapped function will be called.
pub struct BurstGuard {
sender: Sender<BurstGuardEvent>,
/// This is the period of time the `BurstGuard` will wait for a new trigger to be sent
/// before it calls the callback.
buffer_period: Duration,
/// This is the longest period that the `BurstGuard` will wait from the first trigger till
/// it calls the callback.
longest_buffer_period: Duration,
}

enum BurstGuardEvent {
Trigger,
Trigger(Duration),
Shutdown(Sender<()>),
}

impl BurstGuard {
pub fn new<F: Fn() + Send + 'static>(callback: F) -> Self {
/// This is the period of time the `BurstGuard` will wait for a new trigger to be sent
/// before it calls the callback.
const BURST_BUFFER_PERIOD: Duration = Duration::from_millis(200);
/// This is the longest period that the `BurstGuard` will wait from the first trigger till
/// it calls the callback.
const BURST_LONGEST_BUFFER_PERIOD: Duration = Duration::from_secs(2);

pub fn new<F: Fn() + Send + 'static>(
buffer_period: Duration,
longest_buffer_period: Duration,
callback: F,
) -> Self {
let (sender, listener) = channel();
std::thread::spawn(move || {
// The `stop` implementation assumes that this thread will not call `callback` again
// if the listener has been dropped.
while let Ok(message) = listener.recv() {
match message {
BurstGuardEvent::Trigger => {
BurstGuardEvent::Trigger(mut period) => {
let start = Instant::now();
loop {
match listener.recv_timeout(BURST_BUFFER_PERIOD) {
Ok(BurstGuardEvent::Trigger) => {
if start.elapsed() >= BURST_LONGEST_BUFFER_PERIOD {
match listener.recv_timeout(period) {
Ok(BurstGuardEvent::Trigger(new_period)) => {
period = new_period;
let max_period = std::cmp::max(longest_buffer_period, period);
if start.elapsed() >= max_period {
callback();
break;
}
Expand All @@ -67,7 +72,11 @@ impl BurstGuard {
}
}
});
Self { sender }
Self {
sender,
buffer_period,
longest_buffer_period,
}
}

/// When `stop` returns an then the `BurstGuard` thread is guaranteed to not make any further
Expand All @@ -90,6 +99,13 @@ impl BurstGuard {

/// Asynchronously trigger burst
pub fn trigger(&self) {
self.sender.send(BurstGuardEvent::Trigger).unwrap();
self.trigger_with_period(self.buffer_period)
}

/// Asynchronously trigger burst
pub fn trigger_with_period(&self, buffer_period: Duration) {
self.sender
.send(BurstGuardEvent::Trigger(buffer_period))
.unwrap();
}
}
Loading
Loading