Skip to content

Commit

Permalink
Track last error atomically
Browse files Browse the repository at this point in the history
Eliminate the need for UdpSocket to be passed mutably into send_mmsg().
  • Loading branch information
Scott Hutton committed Jul 10, 2024
1 parent 4b1e816 commit 8ccd2ac
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 26 deletions.
10 changes: 4 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
use std::{
net::{IpAddr, Ipv6Addr, SocketAddr},
sync::atomic::{AtomicUsize, Ordering},
time::{Duration, Instant},
};

pub use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
use imp::LastSendError;
use tracing::warn;

mod cmsg;
Expand Down Expand Up @@ -94,20 +94,18 @@ impl Default for RecvMeta {
}

/// Log at most 1 IO error per minute
const IO_ERROR_LOG_INTERVAL: Duration = std::time::Duration::from_secs(60);
const IO_ERROR_LOG_INTERVAL: u64 = 60;

/// Logs a warning message when sendmsg fails
///
/// Logging will only be performed if at least [`IO_ERROR_LOG_INTERVAL`]
/// has elapsed since the last error was logged.
fn log_sendmsg_error<B: AsPtr<u8>>(
last_send_error: &mut Instant,
last_send_error: LastSendError,
err: impl core::fmt::Debug,
transmit: &Transmit<B>,
) {
let now = Instant::now();
if now.saturating_duration_since(*last_send_error) > IO_ERROR_LOG_INTERVAL {
*last_send_error = now;
if last_send_error.should_log() {
warn!(
"sendmsg error: {:?}, Transmit: {{ destination: {:?}, src_ip: {:?}, enc: {:?}, len: {:?}, segment_size: {:?} }}",
err, transmit.dst, transmit.src, transmit.ecn, transmit.contents.len(), transmit.segment_size);
Expand Down
75 changes: 55 additions & 20 deletions src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use std::{
mem::{self, MaybeUninit},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::{fd::AsFd, unix::io::AsRawFd},
sync::atomic::AtomicUsize,
sync::{
atomic::{AtomicU64, AtomicUsize, Ordering},
Arc,
},
task::{Context, Poll},
time::Instant,
time::SystemTime,
};

use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
Expand All @@ -31,7 +34,7 @@ type IpTosTy = libc::c_int;
#[derive(Debug)]
pub struct UdpSocket {
io: tokio::net::UdpSocket,
last_send_error: Instant,
last_send_error: LastSendError,
}

impl AsRawFd for UdpSocket {
Expand All @@ -46,16 +49,47 @@ impl AsFd for UdpSocket {
}
}

#[derive(Clone, Debug)]
pub(crate) struct LastSendError(Arc<AtomicU64>);

impl Default for LastSendError {
fn default() -> Self {
let now = Self::now();
Self(Arc::new(AtomicU64::new(
now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
)))
}
}

impl LastSendError {
fn now() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs()
}

// Determine whether the last error was more than IO_ERROR_LOG_INTERVAL
// seconds ago. If so, update the last error time and return true.
pub(crate) fn should_log(&self) -> bool {
let now = Self::now();
self.0
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |cur| {
((now - cur) > IO_ERROR_LOG_INTERVAL).then_some(now)
})
.is_ok()
}
}

impl UdpSocket {
/// Creates a new UDP socket from a previously created `std::net::UdpSocket`
pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpSocket> {
socket.set_nonblocking(true)?;

init(SockRef::from(&socket))?;
let now = Instant::now();
Ok(UdpSocket {
io: tokio::net::UdpSocket::from_std(socket)?,
last_send_error: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
last_send_error: LastSendError::default(),
})
}

Expand All @@ -67,10 +101,9 @@ impl UdpSocket {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
let io = tokio::net::UdpSocket::bind(addr).await?;
init(SockRef::from(&io))?;
let now = Instant::now();
Ok(UdpSocket {
io,
last_send_error: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
last_send_error: LastSendError::default(),
})
}

Expand Down Expand Up @@ -195,13 +228,13 @@ impl UdpSocket {
///
/// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
pub async fn send_mmsg<B: AsPtr<u8>>(
&mut self,
&self,
state: &UdpState,
transmits: &[Transmit<B>],
) -> Result<usize, io::Error> {
let n = loop {
self.io.writable().await?;
let last_send_error = &mut self.last_send_error;
let last_send_error = self.last_send_error.clone();
let io = &self.io;
match io.try_io(Interest::WRITABLE, || {
send(state, SockRef::from(io), last_send_error, transmits)
Expand Down Expand Up @@ -278,11 +311,15 @@ impl UdpSocket {
transmits: &[Transmit<B>],
) -> Poll<io::Result<usize>> {
loop {
let last_send_error = &mut self.last_send_error;
ready!(self.io.poll_send_ready(cx))?;
let io = &self.io;
if let Ok(res) = io.try_io(Interest::WRITABLE, || {
send(state, SockRef::from(io), last_send_error, transmits)
send(
state,
SockRef::from(io),
self.last_send_error.clone(),
transmits,
)
}) {
return Poll::Ready(Ok(res));
}
Expand Down Expand Up @@ -353,7 +390,7 @@ pub mod sync {
#[derive(Debug)]
pub struct UdpSocket {
io: std::net::UdpSocket,
last_send_error: Instant,
last_send_error: LastSendError,
}

impl AsRawFd for UdpSocket {
Expand All @@ -372,21 +409,19 @@ pub mod sync {
pub fn from_std(socket: std::net::UdpSocket) -> io::Result<Self> {
init(SockRef::from(&socket))?;
socket.set_nonblocking(false)?;
let now = Instant::now();
Ok(Self {
io: socket,
last_send_error: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
last_send_error: LastSendError::default(),
})
}
/// create a new UDP socket and attempt to bind to `addr`
pub fn bind<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
let io = std::net::UdpSocket::bind(addr)?;
init(SockRef::from(&io))?;
io.set_nonblocking(false)?;
let now = Instant::now();
Ok(Self {
io,
last_send_error: now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
last_send_error: LastSendError::default(),
})
}
/// sets nonblocking mode
Expand Down Expand Up @@ -474,7 +509,7 @@ pub mod sync {
send(
state,
SockRef::from(&self.io),
&mut self.last_send_error,
self.last_send_error.clone(),
transmits,
)
}
Expand Down Expand Up @@ -681,7 +716,7 @@ fn send_msg<B: AsPtr<u8>>(
fn send<B: AsPtr<u8>>(
state: &UdpState,
io: SockRef<'_>,
last_send_error: &mut Instant,
last_send_error: LastSendError,
transmits: &[Transmit<B>],
) -> io::Result<usize> {
use std::ptr;
Expand Down Expand Up @@ -802,7 +837,7 @@ fn send_msg<B: AsPtr<u8>>(
fn send<B: AsPtr<u8>>(
_state: &UdpState,
io: SockRef<'_>,
last_send_error: &mut Instant,
last_send_error: LastSendError,
transmits: &[Transmit<B>],
) -> io::Result<usize> {
let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
Expand All @@ -828,7 +863,7 @@ fn send<B: AsPtr<u8>>(
// Those are not fatal errors, since the
// configuration can be dynamically changed.
// - Destination unreachable errors have been observed for other
log_sendmsg_error(last_send_error, e, &transmits[sent]);
log_sendmsg_error(last_send_error.clone(), e, &transmits[sent]);
sent += 1;
}
}
Expand Down

0 comments on commit 8ccd2ac

Please sign in to comment.