diff --git a/forwarder/Cargo.toml b/forwarder/Cargo.toml index f66c986..e0b9c5c 100644 --- a/forwarder/Cargo.toml +++ b/forwarder/Cargo.toml @@ -8,5 +8,5 @@ anyhow = "1.0.71" log = "0.4.20" etherparse = "0.13.0" socket2 = { version = "0.5.5", features = ["all"] } -mio = { version = "1.0.2", features = ["net", "os-ext", "os-poll"] } +mio = { version = "1.0.2", features = ["net", "os-poll"] } parking_lot = "0.12.3" diff --git a/forwarder/src/lib.rs b/forwarder/src/lib.rs index 30c9b79..0671944 100644 --- a/forwarder/src/lib.rs +++ b/forwarder/src/lib.rs @@ -150,9 +150,9 @@ fn try_cleanup(peer_manager: &RwLock) { for peer in peers.get_all() { let used = peer.reset_used(); if !used { - let client_addr = peer.get_client_addr(); + let client_addr = *peer.get_client_addr(); log::info!("cleaning peer that handled '{client_addr}'"); - if let Err(error) = peers.remove_peer(&peer) { + if let Err(error) = peers.remove_peer(peer) { log::warn!("couldn't remove peer of '{client_addr}': {error:?}"); } } else { diff --git a/forwarder/src/peer.rs b/forwarder/src/peer.rs index 9ced9d9..f933ae1 100644 --- a/forwarder/src/peer.rs +++ b/forwarder/src/peer.rs @@ -76,13 +76,13 @@ impl PeerManager { } } - pub fn add_peer(&mut self, new_peer: Peer) -> anyhow::Result> { + pub fn add_peer(&mut self, mut new_peer: Peer) -> anyhow::Result> { let client_addr = new_peer.client_addr; + self.registry.register(&mut new_peer.socket)?; let peer = Arc::new(new_peer); self.client_addr_to_peers.insert(client_addr, peer.clone()); let peer_port = peer.socket.local_addr()?.port(); self.port_to_peers.insert(peer_port, peer.clone()); - self.registry.register(&peer.socket)?; Ok(peer) } @@ -100,10 +100,13 @@ impl PeerManager { self.client_addr_to_peers.values().cloned().collect() } - pub fn remove_peer(&mut self, peer: &Peer) -> anyhow::Result<()> { - self.registry.deregister(&peer.socket)?; + pub fn remove_peer(&mut self, peer: Arc) -> anyhow::Result<()> { self.client_addr_to_peers.remove(&peer.client_addr); self.port_to_peers.remove(&peer.socket.local_addr()?.port()); + + let mut peer = + Arc::try_unwrap(peer).map_err(|_| anyhow::anyhow!("can't unwrap Arc"))?; + self.registry.deregister(&mut peer.socket)?; Ok(()) } } diff --git a/forwarder/src/poll.rs b/forwarder/src/poll.rs index 773a247..6d866a9 100644 --- a/forwarder/src/poll.rs +++ b/forwarder/src/poll.rs @@ -25,8 +25,8 @@ pub trait Poll: Send { /// trait that allows others to register socket to `Poll` pub trait Registry: Send + Sync { // need Sync because parking_lot::RwLock needs inner to be Sync - fn register(&self, socket: &NonBlockingSocket) -> anyhow::Result<()>; - fn deregister(&self, socket: &NonBlockingSocket) -> anyhow::Result<()>; + fn register(&self, socket: &mut NonBlockingSocket) -> anyhow::Result<()>; + fn deregister(&self, socket: &mut NonBlockingSocket) -> anyhow::Result<()>; } mod icmp; diff --git a/forwarder/src/poll/icmp.rs b/forwarder/src/poll/icmp.rs index ae5e392..e768dd3 100644 --- a/forwarder/src/poll/icmp.rs +++ b/forwarder/src/poll/icmp.rs @@ -51,10 +51,10 @@ impl Poll for IcmpPoll { pub struct IcmpRegistry; // icmp doesn't need a registry because we manage it's poll ourself impl Registry for IcmpRegistry { - fn register(&self, _socket: &NonBlockingSocket) -> anyhow::Result<()> { + fn register(&self, _socket: &mut NonBlockingSocket) -> anyhow::Result<()> { Ok(()) } - fn deregister(&self, _socket: &NonBlockingSocket) -> anyhow::Result<()> { + fn deregister(&self, _socket: &mut NonBlockingSocket) -> anyhow::Result<()> { Ok(()) } } diff --git a/forwarder/src/poll/udp.rs b/forwarder/src/poll/udp.rs index e8e5fbf..8144968 100644 --- a/forwarder/src/poll/udp.rs +++ b/forwarder/src/poll/udp.rs @@ -4,7 +4,7 @@ use crate::{ socket::{NonBlockingSocket, NonBlockingSocketTrait}, MAX_PACKET_SIZE, }; -use mio::{unix::SourceFd, Events, Interest, Token}; +use mio::{Events, Interest, Token}; use parking_lot::RwLock; use std::sync::Arc; @@ -50,22 +50,20 @@ impl Poll for UdpPoll { #[derive(Debug)] pub struct UdpRegistry(pub mio::Registry); impl Registry for UdpRegistry { - fn register(&self, socket: &NonBlockingSocket) -> anyhow::Result<()> { - let socket = socket.as_udp().unwrap(); + fn register(&self, socket: &mut NonBlockingSocket) -> anyhow::Result<()> { + let socket = socket.as_mut_udp().unwrap(); let local_port = socket.local_addr()?.port(); self.0.register( - &mut SourceFd(&socket.as_raw_fd()), + socket.as_inner(), Token(local_port.into()), Interest::READABLE, )?; Ok(()) } - fn deregister(&self, socket: &NonBlockingSocket) -> anyhow::Result<()> { - let socket = socket.as_udp().unwrap(); - let raw_fd = socket.as_raw_fd(); - let source = &mut SourceFd(&raw_fd); - self.0.deregister(source)?; + fn deregister(&self, socket: &mut NonBlockingSocket) -> anyhow::Result<()> { + let socket = socket.as_mut_udp().unwrap(); + self.0.deregister(socket.as_inner())?; Ok(()) } } diff --git a/forwarder/src/socket.rs b/forwarder/src/socket.rs index 8d083f4..1ab16d7 100644 --- a/forwarder/src/socket.rs +++ b/forwarder/src/socket.rs @@ -68,7 +68,7 @@ impl NonBlockingSocket { Ok(socket) } - pub fn as_udp(&self) -> Option<&udp::NonBlockingUdpSocket> { + pub fn as_mut_udp(&mut self) -> Option<&mut udp::NonBlockingUdpSocket> { match self { Self::Udp(inner) => Some(inner), _ => None, diff --git a/forwarder/src/socket/udp.rs b/forwarder/src/socket/udp.rs index d37b2c4..34bf666 100644 --- a/forwarder/src/socket/udp.rs +++ b/forwarder/src/socket/udp.rs @@ -1,9 +1,5 @@ use super::{NonBlockingSocketTrait, SocketTrait}; -use std::{ - io, - net::SocketAddr, - os::fd::{AsRawFd, RawFd}, -}; +use std::{io, net::SocketAddr}; #[derive(Debug)] pub struct UdpSocket(std::net::UdpSocket); @@ -30,17 +26,16 @@ impl SocketTrait for UdpSocket { } #[derive(Debug)] -pub struct NonBlockingUdpSocket(std::net::UdpSocket); +pub struct NonBlockingUdpSocket(mio::net::UdpSocket); impl NonBlockingUdpSocket { pub fn bind(address: &SocketAddr) -> io::Result { - let socket = std::net::UdpSocket::bind(address)?; - socket.set_nonblocking(true)?; + let socket = mio::net::UdpSocket::bind(*address)?; Ok(Self(socket)) } - pub fn as_raw_fd(&self) -> RawFd { - self.0.as_raw_fd() + pub fn as_inner(&mut self) -> &mut mio::net::UdpSocket { + &mut self.0 } } @@ -50,7 +45,7 @@ impl NonBlockingSocketTrait for NonBlockingUdpSocket { } fn connect(&mut self, addr: &SocketAddr) -> io::Result<()> { - self.0.connect(addr) + self.0.connect(*addr) } fn recv(&self, buffer: &mut [u8]) -> io::Result {