Skip to content

Commit

Permalink
Hide logic for WgGoTunnel connectivty check
Browse files Browse the repository at this point in the history
  • Loading branch information
kl committed Nov 18, 2024
1 parent 456fbba commit 2fbbed6
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 152 deletions.
243 changes: 148 additions & 95 deletions talpid-wireguard/src/connectivity_check.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#[cfg(target_os = "android")]
use super::Tunnel;
use super::{TunnelError, TunnelType};
#[cfg(target_os = "android")]
use crate::wireguard_go::WgGoTunnel;
use crate::{
ping_monitor::{new_pinger, Pinger},
stats::StatsMap,
Expand Down Expand Up @@ -75,36 +73,50 @@ pub enum Error {
/// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`.
pub struct ConnectivityMonitor {
conn_state: ConnState,
initial_ping_timestamp: Option<Instant>,
num_pings_sent: u32,
pinger: Box<dyn Pinger>,
close_receiver: mpsc::Receiver<()>,
ping_state: PingState,
close_receiver: Option<mpsc::Receiver<()>>,
}

impl ConnectivityMonitor {
pub(super) fn new(
addr: Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
close_receiver: mpsc::Receiver<()>,
) -> Result<Self, Error> {
let pinger = new_pinger(
addr,
#[cfg(any(target_os = "macos", target_os = "linux"))]
interface,
)
.map_err(Error::PingError)?;

let now = Instant::now();

Ok(Self {
conn_state: ConnState::new(now, Default::default()),
initial_ping_timestamp: None,
num_pings_sent: 0,
pinger,
close_receiver,
conn_state: ConnState::new(Instant::now(), Default::default()),
ping_state: PingState::new(
addr,
#[cfg(any(target_os = "macos", target_os = "linux"))]
interface,
)?,
close_receiver: None,
})
}

pub(super) fn with_close_receiver(self, close_receiver: mpsc::Receiver<()>) -> Self {
Self {
close_receiver: Some(close_receiver),
..self
}
}

/// Returns true if monitor should be shut down
fn should_shut_down(&mut self, timeout: Duration) -> bool {
let Some(close_receiver) = self.close_receiver.as_ref() else {
return false;
};

match close_receiver.recv_timeout(timeout) {
Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
Err(mpsc::RecvTimeoutError::Timeout) => false,
}
}

fn reset(&mut self, current_iteration: Instant) {
self.ping_state.reset();
self.conn_state.reset_after_suspension(current_iteration);
}

// checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is
// successful at the start of a connection.
pub(super) fn establish_connectivity(
Expand All @@ -113,7 +125,10 @@ impl ConnectivityMonitor {
tunnel_handle: &TunnelType,
) -> Result<bool, Error> {
// Send initial ping to prod WireGuard into connecting.
self.pinger.send_icmp().map_err(Error::PingError)?;
self.ping_state
.pinger
.send_icmp()
.map_err(Error::PingError)?;
self.establish_connectivity_inner(
retry_attempt,
ESTABLISH_TIMEOUT,
Expand Down Expand Up @@ -152,59 +167,6 @@ impl ConnectivityMonitor {
Ok(false)
}

pub(super) fn run(
&mut self,
tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
) -> Result<(), Error> {
self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle)
}

/// Returns true if monitor should be shut down
fn should_shut_down(&mut self, timeout: Duration) -> bool {
match self.close_receiver.recv_timeout(timeout) {
Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
Err(mpsc::RecvTimeoutError::Timeout) => false,
}
}

fn wait_loop(
&mut self,
iter_delay: Duration,
tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
) -> Result<(), Error> {
let mut last_iteration = Instant::now();
while !self.should_shut_down(iter_delay) {
let mut current_iteration = Instant::now();
let time_slept = current_iteration - last_iteration;
if time_slept < (iter_delay * 2) {
let Some(tunnel) = tunnel_handle.upgrade() else {
return Ok(());
};
let lock = tunnel.blocking_lock();
let Some(tunnel) = lock.as_ref() else {
return Ok(());
};

if !self.check_connectivity(Instant::now(), tunnel)? {
return Ok(());
}
drop(lock);

let end = Instant::now();
if end - current_iteration > Duration::from_secs(1) {
current_iteration = end;
}
} else {
// Loop was suspended for too long, so it's safer to assume that the host still has
// connectivity.
self.reset_pinger();
self.conn_state.reset_after_suspension(current_iteration);
}
last_iteration = current_iteration;
}
Ok(())
}

/// Returns true if connection is established
fn check_connectivity(
&mut self,
Expand All @@ -227,12 +189,12 @@ impl ConnectivityMonitor {
let new_stats = new_stats?;

if self.conn_state.update(now, new_stats) {
self.reset_pinger();
self.ping_state.reset();
return Ok(true);
}

self.maybe_send_ping(now)?;
Ok(!self.ping_timed_out(timeout) && self.conn_state.connected())
Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected())
}
}
}
Expand All @@ -258,33 +220,25 @@ impl ConnectivityMonitor {
// 3 seconds.
if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out())
&& self
.ping_state
.initial_ping_timestamp
.map(|initial_ping_timestamp| {
initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING
initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent
< SECONDS_PER_PING
})
.unwrap_or(true)
{
self.pinger.send_icmp().map_err(Error::PingError)?;
if self.initial_ping_timestamp.is_none() {
self.initial_ping_timestamp = Some(now);
self.ping_state
.pinger
.send_icmp()
.map_err(Error::PingError)?;
if self.ping_state.initial_ping_timestamp.is_none() {
self.ping_state.initial_ping_timestamp = Some(now);
}
self.num_pings_sent += 1;
self.ping_state.num_pings_sent += 1;
}
Ok(())
}

fn ping_timed_out(&self, timeout: Duration) -> bool {
self.initial_ping_timestamp
.map(|initial_ping_timestamp| initial_ping_timestamp.elapsed() > timeout)
.unwrap_or(false)
}

/// Reset timeouts - assume that the last time bytes were received is now.
fn reset_pinger(&mut self) {
self.initial_ping_timestamp = None;
self.num_pings_sent = 0;
self.pinger.reset();
}
}

enum ConnState {
Expand All @@ -300,6 +254,45 @@ enum ConnState {
},
}

struct PingState {
initial_ping_timestamp: Option<Instant>,
num_pings_sent: u32,
pinger: Box<dyn Pinger>,
}

impl PingState {
pub(super) fn new(
addr: Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] interface: String,
) -> Result<Self, Error> {
let pinger = new_pinger(
addr,
#[cfg(any(target_os = "macos", target_os = "linux"))]
interface,
)
.map_err(Error::PingError)?;

Ok(Self {
initial_ping_timestamp: None,
num_pings_sent: 0,
pinger,
})
}

fn ping_timed_out(&self, timeout: Duration) -> bool {
self.initial_ping_timestamp
.map(|initial_ping_timestamp| initial_ping_timestamp.elapsed() > timeout)
.unwrap_or(false)
}

/// Reset timeouts - assume that the last time bytes were received is now.
fn reset(&mut self) {
self.initial_ping_timestamp = None;
self.num_pings_sent = 0;
self.pinger.reset();
}
}

impl ConnState {
pub fn new(start: Instant, stats: StatsMap) -> Self {
ConnState::Connecting {
Expand Down Expand Up @@ -418,6 +411,66 @@ impl ConnState {
}
}

pub struct ConnectivityMonitorLoop {
connectivity_monitor: ConnectivityMonitor,
}

impl ConnectivityMonitorLoop {
pub(super) fn new(connectivity_monitor: ConnectivityMonitor) -> Self {
debug_assert!(
connectivity_monitor.close_receiver.is_some(),
"Close receiver must be set"
);
Self {
connectivity_monitor,
}
}

pub(super) fn run(self, tunnel_handle: Weak<Mutex<Option<TunnelType>>>) -> Result<(), Error> {
self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle)
}

fn wait_loop(
mut self,
iter_delay: Duration,
tunnel_handle: Weak<Mutex<Option<TunnelType>>>,
) -> Result<(), Error> {
let mut last_iteration = Instant::now();
while !self.connectivity_monitor.should_shut_down(iter_delay) {
let mut current_iteration = Instant::now();
let time_slept = current_iteration - last_iteration;
if time_slept < (iter_delay * 2) {
let Some(tunnel) = tunnel_handle.upgrade() else {
return Ok(());
};
let lock = tunnel.blocking_lock();
let Some(tunnel) = lock.as_ref() else {
return Ok(());
};

if !self
.connectivity_monitor
.check_connectivity(Instant::now(), tunnel)?
{
return Ok(());
}
drop(lock);

let end = Instant::now();
if end - current_iteration > Duration::from_secs(1) {
current_iteration = end;
}
} else {
// Loop was suspended for too long, so it's safer to assume that the host still has
// connectivity.
self.connectivity_monitor.reset(current_iteration);
}
last_iteration = current_iteration;
}
Ok(())
}
}

#[cfg(test)]
mod test {
use futures::Future;
Expand Down
Loading

0 comments on commit 2fbbed6

Please sign in to comment.