diff --git a/boringtun/src/device/epoll.rs b/boringtun/src/device/epoll.rs index b6ecaf0b..7fb86dfa 100644 --- a/boringtun/src/device/epoll.rs +++ b/boringtun/src/device/epoll.rs @@ -50,6 +50,17 @@ struct Event { impl Drop for EventPoll { fn drop(&mut self) { + // The only other struct that holds EventRef is the Device, which will necessarily be + // dropped before EventPoll, so closing all fds is fine. + let events = self.events.lock(); + events + .iter() + .enumerate() + .filter(|(_i, o)| o.is_some()) // This is inefficient but shouldn't be a problem + .for_each(|(i, _o)| { + unsafe { close(i as _) }; + }); + unsafe { close(self.epoll) }; } } @@ -80,7 +91,7 @@ impl EventPoll { /// When triggered, one of the threads waiting on the poll will receive the handler via an /// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to /// the handler at any given time. - pub fn new_event(&self, trigger: RawFd, handler: H) -> Result { + pub fn new_event(&self, trigger: RawFd, handler: H) -> Result<(), Error> { // Create an event descriptor let flags = EPOLLIN | EPOLLONESHOT; let ev = Event { @@ -94,14 +105,15 @@ impl EventPoll { needs_read: false, }; - self.register_event(ev) + self.register_event(ev)?; + Ok(()) } /// Add and enable a new write event with the factory. /// The event is triggered when a Write operation on the provided trigger becomes possible /// For TCP sockets it means that the socket was succesfully connected #[allow(dead_code)] - pub fn new_write_event(&self, trigger: RawFd, handler: H) -> Result { + pub fn new_write_event(&self, trigger: RawFd, handler: H) -> Result<(), Error> { // Create an event descriptor let flags = EPOLLOUT | EPOLLET | EPOLLONESHOT; let ev = Event { @@ -115,13 +127,14 @@ impl EventPoll { needs_read: false, }; - self.register_event(ev) + self.register_event(ev)?; + Ok(()) } /// Add and enable a new timed event with the factory. /// The even will be triggered for the first time after period time, and henceforth triggered /// every period time. Period is counted from the moment the appropriate EventGuard is released. - pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result { + pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result<(), Error> { // The periodic event on Linux uses the timerfd let tfd = match unsafe { timerfd_create(CLOCK_BOOTTIME, TFD_NONBLOCK) } { -1 => match unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK) } { @@ -158,7 +171,8 @@ impl EventPoll { needs_read: true, }; - self.register_event(ev) + self.register_event(ev)?; + Ok(()) } /// Add and enable a new notification event with the factory. @@ -191,7 +205,7 @@ impl EventPoll { } /// Add and enable a new signal handler - pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result { + pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result<(), Error> { let sfd = match unsafe { let mut sigset = std::mem::zeroed(); sigemptyset(&mut sigset); @@ -214,7 +228,8 @@ impl EventPoll { needs_read: true, }; - self.register_event(ev) + self.register_event(ev)?; + Ok(()) } /// Wait until one of the registered events becomes triggered. Once an event diff --git a/boringtun/src/device/integration_tests/mod.rs b/boringtun/src/device/integration_tests/mod.rs index b4e360c3..5f8caf5a 100644 --- a/boringtun/src/device/integration_tests/mod.rs +++ b/boringtun/src/device/integration_tests/mod.rs @@ -846,4 +846,41 @@ mod tests { t.join().unwrap(); } } + + #[cfg(target_os = "linux")] + fn count_fds() -> u16 { + let entries = std::fs::read_dir("/proc/self/fd/").unwrap(); + let mut i = 0; + for entry in entries { + // This will panic if EventPoll's Drop happens at the same time + let _entry = entry.unwrap(); + + i += 1; + } + i + } + + #[cfg(target_os = "linux")] + #[test] + fn test_fd_leaks() { + let n_before = count_fds(); + let queue = { + let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap()); + let response = wg.wg_get(); + assert!(response.ends_with("errno=0\n\n")); + let device = wg._device.device.read(); + device.queue.clone() + }; + + // WGHandle._device (DeviceHandle) triggers an exit notice to the queue. When Device is dropped the queue + // (Arc>) will eventually be dropped too. Only when the queue is dropped will the fds be + // closed. Wait until we are the last reference and then drop the queue. + while Arc::strong_count(&queue) > 1 { + std::thread::yield_now(); + } + drop(queue); // dropping fds now + + let n_after = count_fds(); + assert_eq!(n_before, n_after); + } }