diff --git a/src/lib.rs b/src/lib.rs index a0a56a52..f728bbba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -701,3 +701,232 @@ impl<'name, 'bufs, 'control> fmt::Debug for MsgHdrMut<'name, 'bufs, 'control> { "MsgHdrMut".fmt(fmt) } } + +/// Configuration of a `sendmmsg(2)` system call. +/// +/// This wraps `mmsghdr` on Unix. Also see [`MMsgHdrMut`] for the variant used by `recvmmsg(2)`. +/// This API is not available on Windows. +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +pub struct MMsgHdr<'addr, 'bufs, 'control> { + inner: sys::mmsghdr, + #[allow(clippy::type_complexity)] + _lifetimes: PhantomData<(&'addr SockAddr, &'bufs IoSlice<'bufs>, &'control [u8])>, +} + +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +impl<'addr, 'bufs, 'control> MMsgHdr<'addr, 'bufs, 'control> { + /// Create a new `MMsgHdr` with all empty/zero fields. + #[allow(clippy::new_without_default)] + pub fn new() -> MMsgHdr<'addr, 'bufs, 'control> { + // SAFETY: all zero is valid for `mmsghdr`. + MMsgHdr { + inner: unsafe { mem::zeroed() }, + _lifetimes: PhantomData, + } + } + + /// Create a new `MMsgHdr` from a `MsgHdr`. + pub fn from_msghdr(msghdr: MsgHdr<'addr, 'bufs, 'control>) -> MMsgHdr<'addr, 'bufs, 'control> { + MMsgHdr { + inner: sys::mmsghdr { + msg_hdr: msghdr.inner, + msg_len: 0, + }, + _lifetimes: PhantomData, + } + } + + /// Set the address (name) of the message. + /// + /// Corresponds to setting `msg_name` and `msg_namelen` on Unix. + pub fn with_addr(mut self, addr: &'addr SockAddr) -> Self { + sys::set_msghdr_name(&mut self.inner.msg_hdr, addr); + self + } + + /// Set the buffer(s) of the message. + /// + /// Corresponds to setting `msg_iov` and `msg_iovlen` on Unix. + pub fn with_buffers(mut self, bufs: &'bufs [IoSlice<'_>]) -> Self { + let ptr = bufs.as_ptr() as *mut _; + sys::set_msghdr_iov(&mut self.inner.msg_hdr, ptr, bufs.len()); + self + } + + /// Set the control buffer of the message. + /// + /// Corresponds to setting `msg_control` and `msg_controllen` on Unix. + pub fn with_control(mut self, buf: &'control [u8]) -> Self { + let ptr = buf.as_ptr() as *mut _; + sys::set_msghdr_control(&mut self.inner.msg_hdr, ptr, buf.len()); + self + } + + /// Set the flags of the message. + /// + /// Corresponds to setting `msg_flags` on Unix. + pub fn with_flags(mut self, flags: sys::c_int) -> Self { + sys::set_msghdr_flags(&mut self.inner.msg_hdr, flags); + self + } + + /// Gets the number of sent bytes. + /// + /// Corresponds to `msg_len` on Unix. + pub fn data_len(&self) -> usize { + self.inner.msg_len as usize + } +} + +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +impl<'name, 'bufs, 'control> fmt::Debug for MMsgHdr<'name, 'bufs, 'control> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "MMsgHdr".fmt(fmt) + } +} + +/// Configuration of a `recvmmsg(2)` system call. +/// +/// This wraps `mmsghdr` on Unix. Also see [`MMsgHdr`] for the variant used by `sendmmsg(2)`. +/// This API is not available on Windows. +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +pub struct MMsgHdrMut<'addr, 'bufs, 'control> { + inner: sys::mmsghdr, + #[allow(clippy::type_complexity)] + _lifetimes: PhantomData<( + &'addr mut SockAddr, + &'bufs mut MaybeUninitSlice<'bufs>, + &'control mut [u8], + )>, +} + +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +impl<'addr, 'bufs, 'control> MMsgHdrMut<'addr, 'bufs, 'control> { + /// Create a new `MMsgHdrMut` with all empty/zero fields. + #[allow(clippy::new_without_default)] + pub fn new() -> MMsgHdrMut<'addr, 'bufs, 'control> { + // SAFETY: all zero is valid for `mmsghdr`. + MMsgHdrMut { + inner: unsafe { mem::zeroed() }, + _lifetimes: PhantomData, + } + } + + /// Create a new `MMsgHdrMut` from a `MsgHdrMut`. + pub fn from_msghdrmut( + msghdrmut: MsgHdrMut<'addr, 'bufs, 'control>, + ) -> MMsgHdrMut<'addr, 'bufs, 'control> { + MMsgHdrMut { + inner: sys::mmsghdr { + msg_hdr: msghdrmut.inner, + msg_len: 0, + }, + _lifetimes: PhantomData, + } + } + + /// Set the mutable address (name) of the message. + /// + /// Corresponds to setting `msg_name` and `msg_namelen` on Unix. + #[allow(clippy::needless_pass_by_ref_mut)] + pub fn with_addr(mut self, addr: &'addr mut SockAddr) -> Self { + sys::set_msghdr_name(&mut self.inner.msg_hdr, addr); + self + } + + /// Set the mutable buffer(s) of the message. + /// + /// Corresponds to setting `msg_iov` and `msg_iovlen` on Unix. + pub fn with_buffers(mut self, bufs: &'bufs mut [MaybeUninitSlice<'_>]) -> Self { + sys::set_msghdr_iov( + &mut self.inner.msg_hdr, + bufs.as_mut_ptr().cast(), + bufs.len(), + ); + self + } + + /// Set the mutable control buffer of the message. + /// + /// Corresponds to setting `msg_control` and `msg_controllen` on Unix. + pub fn with_control(mut self, buf: &'control mut [MaybeUninit]) -> Self { + sys::set_msghdr_control(&mut self.inner.msg_hdr, buf.as_mut_ptr().cast(), buf.len()); + self + } + + /// Returns the flags of the message. + pub fn flags(&self) -> RecvFlags { + sys::msghdr_flags(&self.inner.msg_hdr) + } + + /// Gets the length of the control buffer. + /// + /// Can be used to determine how much, if any, of the control buffer was filled by `recvmsg`. + /// + /// Corresponds to `msg_controllen` on Unix. + pub fn control_len(&self) -> usize { + sys::msghdr_control_len(&self.inner.msg_hdr) + } + + /// Gets the number of received bytes. + /// + /// Corresponds to `msg_len` on Unix. + pub fn data_len(&self) -> usize { + self.inner.msg_len as usize + } +} + +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +impl<'name, 'bufs, 'control> fmt::Debug for MMsgHdrMut<'name, 'bufs, 'control> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "MMsgHdrMut".fmt(fmt) + } +} diff --git a/src/socket.rs b/src/socket.rs index 1fc6f080..f380c1b4 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -24,6 +24,19 @@ use crate::sys::{self, c_int, getsockopt, setsockopt, Bool}; #[cfg(all(unix, not(target_os = "redox")))] use crate::MsgHdrMut; use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +use crate::{MMsgHdr, MMsgHdrMut}; #[cfg(not(target_os = "redox"))] use crate::{MaybeUninitSlice, MsgHdr, RecvFlags}; @@ -648,6 +661,30 @@ impl Socket { sys::recvmsg(self.as_raw(), msg, flags) } + /// Receive a list of messages on a socket using a message structure. + /// Note that the timeout is buggy on Linux, see BUGS section in the Linux manual page. + #[doc = man_links!(recvmmsg(2))] + #[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) + ))] + pub fn recvmmsg( + &self, + msgs: &mut [MMsgHdrMut<'_, '_, '_>], + flags: sys::c_int, + timeout: Option, + ) -> io::Result { + sys::recvmmsg(self.as_raw(), msgs, flags, timeout) + } + /// Sends data on the socket to a connected peer. /// /// This is typically used on TCP sockets or datagram sockets which have @@ -749,6 +786,28 @@ impl Socket { pub fn sendmsg(&self, msg: &MsgHdr<'_, '_, '_>, flags: sys::c_int) -> io::Result { sys::sendmsg(self.as_raw(), msg, flags) } + + /// Send a list of messages on a socket using a message structure. + #[doc = man_links!(sendmmsg(2))] + #[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) + ))] + pub fn sendmmsg( + &self, + msgs: &mut [MMsgHdr<'_, '_, '_>], + flags: sys::c_int, + ) -> io::Result { + sys::sendmmsg(self.as_raw(), msgs, flags) + } } /// Set `SOCK_CLOEXEC` and `NO_HANDLE_INHERIT` on the `ty`pe on platforms that diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 0421bf88..7fbf1247 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -80,6 +80,19 @@ use libc::ssize_t; use libc::{in6_addr, in_addr}; use crate::{Domain, Protocol, SockAddr, SockAddrStorage, TcpKeepalive, Type}; +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +use crate::{MMsgHdr, MMsgHdrMut}; #[cfg(not(target_os = "redox"))] use crate::{MsgHdr, MsgHdrMut, RecvFlags}; @@ -744,6 +757,18 @@ pub(crate) fn msghdr_control_len(msg: &msghdr) -> usize { msg.msg_controllen as _ } +// Used in `MMsgHdr`. +#[cfg(any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", +))] +pub(crate) use libc::mmsghdr; + /// Unix only API. impl SockAddr { /// Constructs a `SockAddr` with the family `AF_VSOCK` and the provided CID/port. @@ -1098,6 +1123,78 @@ pub(crate) fn recvmsg( syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize) } +// helper function for recvmmsg +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +fn into_timespec(duration: Duration) -> libc::timespec { + // https://github.com/rust-lang/libc/issues/1848 + #[cfg_attr(any(target_env = "musl", target_env = "ohos"), allow(deprecated))] + libc::timespec { + tv_sec: min(duration.as_secs(), libc::time_t::MAX as u64) as libc::time_t, + tv_nsec: duration.subsec_nanos() as libc::c_long, + } +} + +// type of the parameter specifying the number of mmsghdr elements in sendmmsg/recvmmsg syscalls +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +type MMsgHdrLen = libc::c_uint; +#[cfg(all(feature = "all", target_os = "freebsd"))] +type MMsgHdrLen = usize; + +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +pub(crate) fn recvmmsg( + fd: Socket, + msgs: &mut [MMsgHdrMut<'_, '_, '_>], + flags: c_int, + timeout: Option, +) -> io::Result { + let mut ts: libc::timespec; + let tp = match timeout { + Some(d) => { + ts = into_timespec(d); + &mut ts + } + None => std::ptr::null_mut(), + }; + // MMsgHdrMut only contains libc::mmsghdr and PhantomData + let mp = msgs.as_mut_ptr() as *mut libc::mmsghdr; + // flags is unsigned in musl and ohos libc + #[cfg(any(target_env = "musl", target_env = "ohos"))] + let flags = flags.cast_unsigned(); + syscall!(recvmmsg(fd, mp, msgs.len() as MMsgHdrLen, flags, tp)).map(|n| n as usize) +} + pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result { syscall!(send( fd, @@ -1142,6 +1239,31 @@ pub(crate) fn sendmsg(fd: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io: syscall!(sendmsg(fd, &msg.inner, flags)).map(|n| n as usize) } +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +pub(crate) fn sendmmsg( + fd: Socket, + msgs: &mut [MMsgHdr<'_, '_, '_>], + flags: c_int, +) -> io::Result { + // MMsgHdr only contains libc::mmsghdr and PhantomData + let mp = msgs.as_mut_ptr() as *mut libc::mmsghdr; + // flags is unsigned in musl and ohos libc + #[cfg(any(target_env = "musl", target_env = "ohos"))] + let flags = flags.cast_unsigned(); + syscall!(sendmmsg(fd, mp, msgs.len() as MMsgHdrLen, flags)).map(|n| n as usize) +} + /// Wrapper around `getsockopt` to deal with platform specific timeouts. pub(crate) fn timeout_opt(fd: Socket, opt: c_int, val: c_int) -> io::Result> { unsafe { getsockopt(fd, opt, val).map(from_timeval) } @@ -1171,7 +1293,7 @@ pub(crate) fn set_timeout_opt( fn into_timeval(duration: Option) -> libc::timeval { match duration { // https://github.com/rust-lang/libc/issues/1848 - #[cfg_attr(target_env = "musl", allow(deprecated))] + #[cfg_attr(any(target_env = "musl", target_env = "ohos"), allow(deprecated))] Some(duration) => libc::timeval { tv_sec: min(duration.as_secs(), libc::time_t::MAX as u64) as libc::time_t, tv_usec: duration.subsec_micros() as libc::suseconds_t, diff --git a/tests/socket.rs b/tests/socket.rs index 05ba1fe3..e68c490a 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -776,6 +776,120 @@ fn sendmsg() { assert_eq!(received, DATA.len()); } +#[test] +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +fn sendmmsg() { + let (socket_a, socket_b) = udp_pair_unconnected(); + + const DATA1: &[u8] = b"Hello, "; + const DATA2: &[u8] = b"World!"; + + let bufs1 = &[IoSlice::new(DATA1)]; + let bufs2 = &[IoSlice::new(DATA2)]; + let addr_b = socket_b.local_addr().unwrap(); + let mut msgs = Vec::new(); + msgs.push( + socket2::MMsgHdr::new() + .with_addr(&addr_b) + .with_buffers(bufs1), + ); + msgs.push( + socket2::MMsgHdr::new() + .with_addr(&addr_b) + .with_buffers(bufs2), + ); + let sent = socket_a.sendmmsg(&mut msgs, 0).unwrap(); + assert_eq!(sent, msgs.len()); + assert_eq!(msgs[0].data_len(), DATA1.len()); + assert_eq!(msgs[1].data_len(), DATA2.len()); + + let mut buf1 = Vec::with_capacity(DATA1.len() + 1); + let mut buf2 = Vec::with_capacity(DATA2.len() + 1); + let received1 = socket_b.recv(buf1.spare_capacity_mut()).unwrap(); + let received2 = socket_b.recv(buf2.spare_capacity_mut()).unwrap(); + assert_eq!(received1, DATA1.len()); + // SAFETY: recv filled the buffer and received1 is not exceeding the capacity + unsafe { + buf1.set_len(received1); + } + assert_eq!(&buf1[..], DATA1); + assert_eq!(received2, DATA2.len()); + // SAFETY: recv filled the buffer and received2 is not exceeding the capacity + unsafe { + buf2.set_len(received2); + } + assert_eq!(&buf2[..], DATA2); +} + +#[test] +#[cfg(all( + feature = "all", + any( + target_os = "aix", + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ) +))] +fn recvmmsg() { + let (socket_a, socket_b) = udp_pair_unconnected(); + + const DATA1: &[u8] = b"Hello, "; + const DATA2: &[u8] = b"World!"; + + let bufs1 = &[IoSlice::new(DATA1)]; + let bufs2 = &[IoSlice::new(DATA2)]; + let addr_b = socket_b.local_addr().unwrap(); + let msg1 = socket2::MsgHdr::new() + .with_addr(&addr_b) + .with_buffers(bufs1); + let msg2 = socket2::MsgHdr::new() + .with_addr(&addr_b) + .with_buffers(bufs2); + let sent1 = socket_a.sendmsg(&msg1, 0).unwrap(); + let sent2 = socket_a.sendmsg(&msg2, 0).unwrap(); + assert_eq!(sent1, DATA1.len()); + assert_eq!(sent2, DATA2.len()); + + let mut buf1 = Vec::with_capacity(DATA1.len() + 1); + let mut buf2 = Vec::with_capacity(DATA2.len() + 1); + let mut iov1 = [socket2::MaybeUninitSlice::new(buf1.spare_capacity_mut())]; + let mut iov2 = [socket2::MaybeUninitSlice::new(buf2.spare_capacity_mut())]; + let mut msgs = Vec::new(); + msgs.push(socket2::MMsgHdrMut::new().with_buffers(&mut iov1[..])); + msgs.push(socket2::MMsgHdrMut::new().with_buffers(&mut iov2[..])); + let received = socket_b.recvmmsg(&mut msgs, 0, None).unwrap(); + assert_eq!(received, msgs.len()); + let received1 = msgs[0].data_len(); + let received2 = msgs[1].data_len(); + assert_eq!(received1, DATA1.len()); + // SAFETY: recvmmsg filled the buffer and received1 is not exceeding the capacity + unsafe { + buf1.set_len(received1); + } + assert_eq!(received2, DATA2.len()); + // SAFETY: recvmmsg filled the buffer and received1 is not exceeding the capacity + unsafe { + buf2.set_len(received2); + } + assert_eq!(&buf1[..], DATA1); + assert_eq!(&buf2[..], DATA2); +} + #[test] #[cfg(not(any(target_os = "redox", target_os = "vita")))] fn recv_vectored_truncated() {