From 1ee78bc41b23832e7c1137cf4e22d7d26b8c7bf0 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 11 Jun 2017 20:33:50 -0700 Subject: [PATCH 1/2] Add a SockAddr type We don't want to be stuck only supporting IP sockets --- src/lib.rs | 27 ++++++--- src/sockaddr.rs | 130 ++++++++++++++++++++++++++++++++++++++++++++ src/socket.rs | 28 +++++----- src/sys/unix/mod.rs | 105 ++++++++--------------------------- src/sys/windows.rs | 97 ++++++++------------------------- 5 files changed, 210 insertions(+), 177 deletions(-) create mode 100644 src/sockaddr.rs diff --git a/src/lib.rs b/src/lib.rs index a1b133c2..6dab5998 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,13 +20,14 @@ //! # Examples //! //! ```no_run +//! use std::net::SocketAddr; //! use socket2::{Socket, Domain, Type}; //! //! // create a TCP listener bound to two addresses //! let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap(); //! -//! socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap(); -//! socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap(); +//! socket.bind(&"127.0.0.1:12345".parse::().unwrap().into()).unwrap(); +//! socket.bind(&"127.0.0.1:12346".parse::().unwrap().into()).unwrap(); //! socket.listen(128).unwrap(); //! //! let listener = socket.into_tcp_listener(); @@ -45,6 +46,10 @@ use utils::NetInt; +#[cfg(unix)] use libc::{sockaddr_storage, socklen_t}; +#[cfg(windows)] use winapi::{SOCKADDR_STORAGE as sockaddr_storage, socklen_t}; + +mod sockaddr; mod socket; mod utils; @@ -63,13 +68,14 @@ mod utils; /// # Examples /// /// ```no_run -/// use socket2::{Socket, Domain, Type}; +/// use std::net::SocketAddr; +/// use socket2::{Socket, Domain, Type, SockAddr}; /// /// // create a TCP listener bound to two addresses /// let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap(); /// -/// socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap(); -/// socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap(); +/// socket.bind(&"127.0.0.1:12345".parse::().unwrap().into()).unwrap(); +/// socket.bind(&"127.0.0.1:12346".parse::().unwrap().into()).unwrap(); /// socket.listen(128).unwrap(); /// /// let listener = socket.into_tcp_listener(); @@ -79,6 +85,15 @@ pub struct Socket { inner: sys::Socket, } +/// The address of a socket. +/// +/// `SockAddr`s may be constructed directly to and from the standard library +/// `SocketAddr`, `SocketAddrV4`, and `SocketAddrV6` types. +pub struct SockAddr { + storage: sockaddr_storage, + len: socklen_t, +} + /// Specification of the communication domain for a socket. /// /// This is a newtype wrapper around an integer which provides a nicer API in @@ -111,5 +126,3 @@ pub struct Type(i32); pub struct Protocol(i32); fn hton(i: I) -> I { i.to_be() } - -fn ntoh(i: I) -> I { I::from_be(i) } diff --git a/src/sockaddr.rs b/src/sockaddr.rs new file mode 100644 index 00000000..8ac83da8 --- /dev/null +++ b/src/sockaddr.rs @@ -0,0 +1,130 @@ +use std::net::{SocketAddrV4, SocketAddrV6, SocketAddr}; +use std::mem; +use std::ptr; +use std::fmt; + +#[cfg(unix)] +use libc::{sockaddr, sockaddr_storage, sa_family_t, socklen_t, AF_INET, AF_INET6}; +#[cfg(windows)] +use winapi::{SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage, + ADDRESS_FAMILY as sa_family_t, socklen_t, AF_INET, AF_INET6}; + +use SockAddr; + +impl fmt::Debug for SockAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = fmt.debug_struct("SockAddr"); + builder.field("family", &self.family()); + if let Some(addr) = self.as_inet() { + builder.field("inet", &addr); + } else if let Some(addr) = self.as_inet6() { + builder.field("inet6", &addr); + } + builder.finish() + } +} + +impl SockAddr { + /// Constructs a `SockAddr` from its raw components. + pub unsafe fn from_raw_parts(addr: *const sockaddr, len: socklen_t) -> SockAddr { + let mut storage = mem::uninitialized::(); + ptr::copy_nonoverlapping(addr as *const _ as *const u8, + &mut storage as *mut _ as *mut u8, + len as usize); + + SockAddr { + storage: storage, + len: len, + } + } + + unsafe fn as_(&self, family: sa_family_t) -> Option { + if self.storage.ss_family != family { + return None; + } + + Some(mem::transmute_copy(&self.storage)) + } + + /// Returns this address as a `SocketAddrV4` if it is in the `AF_INET` + /// family. + pub fn as_inet(&self) -> Option { + unsafe { self.as_(AF_INET as sa_family_t) } + } + + /// Returns this address as a `SocketAddrV4` if it is in the `AF_INET6` + /// family. + pub fn as_inet6(&self) -> Option { + unsafe { self.as_(AF_INET6 as sa_family_t) } + } + + /// Returns this address's family. + pub fn family(&self) -> sa_family_t { + self.storage.ss_family + } + + /// Returns the size of this address in bytes. + pub fn len(&self) -> socklen_t { + self.len + } + + /// Returns a raw pointer to the address. + pub fn as_ptr(&self) -> *const sockaddr { + &self.storage as *const _ as *const _ + } +} + +// SocketAddrV4 and SocketAddrV6 are just wrappers around sockaddr_in and sockaddr_in6 + +impl From for SockAddr { + fn from(addr: SocketAddrV4) -> SockAddr { + unsafe { + SockAddr::from_raw_parts(&addr as *const _ as *const _, + mem::size_of::() as socklen_t) + } + } +} + + +impl From for SockAddr { + fn from(addr: SocketAddrV6) -> SockAddr { + unsafe { + SockAddr::from_raw_parts(&addr as *const _ as *const _, + mem::size_of::() as socklen_t) + } + } +} + +impl From for SockAddr { + fn from(addr: SocketAddr) -> SockAddr { + match addr { + SocketAddr::V4(addr) => addr.into(), + SocketAddr::V6(addr) => addr.into(), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn inet() { + let raw = "127.0.0.1:80".parse::().unwrap(); + let addr = SockAddr::from(raw); + assert!(addr.as_inet6().is_none()); + let addr = addr.as_inet().unwrap(); + assert_eq!(raw, addr); + } + + #[test] + fn inet6() { + let raw = "[2001:db8::ff00:42:8329]:80" + .parse::() + .unwrap(); + let addr = SockAddr::from(raw); + assert!(addr.as_inet().is_none()); + let addr = addr.as_inet6().unwrap(); + assert_eq!(raw, addr); + } +} diff --git a/src/socket.rs b/src/socket.rs index b66ba9e9..444f5c1a 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -10,7 +10,7 @@ use std::fmt; use std::io::{self, Read, Write}; -use std::net::{self, SocketAddr, Ipv4Addr, Ipv6Addr, Shutdown}; +use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown}; use std::time::Duration; #[cfg(unix)] @@ -19,7 +19,7 @@ use libc as c; use winapi as c; use sys; -use {Socket, Protocol, Domain, Type}; +use {Socket, Protocol, Domain, Type, SockAddr}; impl Socket { /// Creates a new socket ready to be configured. @@ -58,7 +58,7 @@ impl Socket { /// /// An error will be returned if `listen` or `connect` has already been /// called on this builder. - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { + pub fn connect(&self, addr: &SockAddr) -> io::Result<()> { self.inner.connect(addr) } @@ -81,7 +81,7 @@ impl Socket { /// /// If the connection request times out, it may still be processing in the /// background - a second call to `connect` or `connect_timeout` may fail. - pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { + pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> { self.inner.connect_timeout(addr, timeout) } @@ -89,7 +89,7 @@ impl Socket { /// /// This function directly corresponds to the bind(2) function on Windows /// and Unix. - pub fn bind(&self, addr: &SocketAddr) -> io::Result<()> { + pub fn bind(&self, addr: &SockAddr) -> io::Result<()> { self.inner.bind(addr) } @@ -110,19 +110,19 @@ impl Socket { /// This function will block the calling thread until a new connection is /// established. When established, the corresponding `Socket` and the /// remote peer's address will be returned. - pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> { + pub fn accept(&self) -> io::Result<(Socket, SockAddr)> { self.inner.accept().map(|(socket, addr)| { (Socket { inner: socket }, addr) }) } /// Returns the socket address of the local half of this TCP connection. - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { self.inner.local_addr() } /// Returns the socket address of the remote peer of this TCP connection. - pub fn peer_addr(&self) -> io::Result { + pub fn peer_addr(&self) -> io::Result { self.inner.peer_addr() } @@ -184,7 +184,7 @@ impl Socket { /// Receives data from the socket. On success, returns the number of bytes /// read and the address from whence the data came. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { self.inner.recv_from(buf) } @@ -195,7 +195,7 @@ impl Socket { /// /// On success, returns the number of bytes peeked and the address from /// whence the data came. - pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { self.inner.peek_from(buf) } @@ -214,7 +214,7 @@ impl Socket { /// /// This is typically used on UDP or datagram-oriented sockets. On success /// returns the number of bytes that were sent. - pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result { + pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result { self.inner.send_to(buf, addr) } @@ -693,12 +693,14 @@ impl From for i32 { #[cfg(test)] mod test { + use std::net::SocketAddr; + use super::*; #[test] fn connect_timeout_unrouteable() { // this IP is unroutable, so connections should always time out - let addr: SocketAddr = "10.255.255.1:80".parse().unwrap(); + let addr = "10.255.255.1:80".parse::().unwrap().into(); let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap(); match socket.connect_timeout(&addr, Duration::from_millis(250)) { @@ -711,7 +713,7 @@ mod test { #[test] fn connect_timeout_valid() { let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap(); - socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + socket.bind(&"127.0.0.1:0".parse::().unwrap().into()).unwrap(); socket.listen(128).unwrap(); let addr = socket.local_addr().unwrap(); diff --git a/src/sys/unix/mod.rs b/src/sys/unix/mod.rs index 4a0872c2..62c7966d 100644 --- a/src/sys/unix/mod.rs +++ b/src/sys/unix/mod.rs @@ -14,14 +14,14 @@ use std::io::{Read, Write, ErrorKind}; use std::io; use std::mem; use std::net::Shutdown; -use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; +use std::net::{self, Ipv4Addr, Ipv6Addr}; use std::ops::Neg; use std::os::unix::prelude::*; use std::sync::atomic::{AtomicBool, Ordering, ATOMIC_BOOL_INIT}; use std::time::{Duration, Instant}; -use libc::{self, c_void, c_int, sockaddr_in, sockaddr_storage, sockaddr_in6}; -use libc::{sockaddr, socklen_t, AF_INET, AF_INET6, ssize_t}; +use libc::{self, c_void, c_int}; +use libc::{sockaddr, socklen_t, ssize_t}; cfg_if! { if #[cfg(any(target_os = "dragonfly", target_os = "freebsd", @@ -57,6 +57,7 @@ cfg_if! { } } +use SockAddr; use utils::One; #[macro_use] @@ -93,15 +94,9 @@ impl Socket { } } - pub fn bind(&self, addr: &SocketAddr) -> io::Result<()> { - #[cfg(not(all(target_arch = "aarch64",target_os = "android")))] - use libc::socklen_t as len_t; - #[cfg(all(target_arch = "aarch64",target_os = "android"))] - use libc::c_int as len_t; - - let (addr, len) = addr2raw(addr); + pub fn bind(&self, addr: &SockAddr) -> io::Result<()> { unsafe { - cvt(libc::bind(self.fd, addr, len as len_t)).map(|_| ()) + cvt(libc::bind(self.fd, addr.as_ptr(), addr.len() as _)).map(|_| ()) } } @@ -111,14 +106,13 @@ impl Socket { } } - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - let (addr, len) = addr2raw(addr); + pub fn connect(&self, addr: &SockAddr) -> io::Result<()> { unsafe { - cvt(libc::connect(self.fd, addr, len)).map(|_| ()) + cvt(libc::connect(self.fd, addr.as_ptr(), addr.len())).map(|_| ()) } } - pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { + pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> { self.set_nonblocking(true)?; let r = self.connect(addr); self.set_nonblocking(false)?; @@ -179,25 +173,25 @@ impl Socket { } } - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { unsafe { let mut storage: libc::sockaddr_storage = mem::zeroed(); let mut len = mem::size_of_val(&storage) as libc::socklen_t; cvt(libc::getsockname(self.fd, &mut storage as *mut _ as *mut _, &mut len))?; - raw2addr(&storage, len) + Ok(SockAddr::from_raw_parts(&storage as *const _ as *const _, len)) } } - pub fn peer_addr(&self) -> io::Result { + pub fn peer_addr(&self) -> io::Result { unsafe { let mut storage: libc::sockaddr_storage = mem::zeroed(); let mut len = mem::size_of_val(&storage) as libc::socklen_t; cvt(libc::getpeername(self.fd, &mut storage as *mut _ as *mut _, &mut len))?; - raw2addr(&storage, len) + Ok(SockAddr::from_raw_parts(&storage as *const _ as *const _, len)) } } @@ -233,7 +227,7 @@ impl Socket { } #[allow(unused_mut)] - pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> { + pub fn accept(&self) -> io::Result<(Socket, SockAddr)> { let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; let mut len = mem::size_of_val(&storage) as socklen_t; @@ -270,7 +264,7 @@ impl Socket { fd } }; - let addr = raw2addr(&storage, len)?; + let addr = unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }; Ok((socket, addr)) } @@ -334,16 +328,16 @@ impl Socket { } } - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { self.recvfrom(buf, 0) } - pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { self.recvfrom(buf, libc::MSG_PEEK) } fn recvfrom(&self, buf: &mut [u8], flags: c_int) - -> io::Result<(usize, SocketAddr)> { + -> io::Result<(usize, SockAddr)> { unsafe { let mut storage: libc::sockaddr_storage = mem::zeroed(); let mut addrlen = mem::size_of_val(&storage) as socklen_t; @@ -356,7 +350,8 @@ impl Socket { &mut storage as *mut _ as *mut _, &mut addrlen) })?; - Ok((n as usize, raw2addr(&storage, addrlen)?)) + let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, addrlen); + Ok((n as usize, addr)) } } @@ -372,16 +367,15 @@ impl Socket { } } - pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result { + pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result { unsafe { - let (addr, len) = addr2raw(addr); let n = cvt({ libc::sendto(self.fd, buf.as_ptr() as *const c_void, cmp::min(buf.len(), max_len()), MSG_NOSIGNAL, - addr, - len) + addr.as_ptr(), + addr.len()) })?; Ok(n as usize) } @@ -887,59 +881,6 @@ fn set_cloexec(fd: c_int) -> io::Result<()> { } } -fn addr2raw(addr: &SocketAddr) -> (*const sockaddr, socklen_t) { - match *addr { - SocketAddr::V4(ref a) => { - (a as *const _ as *const _, mem::size_of_val(a) as socklen_t) - } - SocketAddr::V6(ref a) => { - (a as *const _ as *const _, mem::size_of_val(a) as socklen_t) - } - } -} - -fn raw2addr(storage: &sockaddr_storage, len: socklen_t) -> io::Result { - match storage.ss_family as c_int { - AF_INET => { - unsafe { - assert!(len as usize >= mem::size_of::()); - let sa = storage as *const _ as *const sockaddr_in; - let bits = ::ntoh((*sa).sin_addr.s_addr); - let ip = Ipv4Addr::new((bits >> 24) as u8, - (bits >> 16) as u8, - (bits >> 8) as u8, - bits as u8); - Ok(SocketAddr::V4(SocketAddrV4::new(ip, ::ntoh((*sa).sin_port)))) - } - } - AF_INET6 => { - unsafe { - assert!(len as usize >= mem::size_of::()); - - let sa = storage as *const _ as *const sockaddr_in6; - let arr = (*sa).sin6_addr.s6_addr; - - let ip = Ipv6Addr::new( - (arr[0] as u16) << 8 | (arr[1] as u16), - (arr[2] as u16) << 8 | (arr[3] as u16), - (arr[4] as u16) << 8 | (arr[5] as u16), - (arr[6] as u16) << 8 | (arr[7] as u16), - (arr[8] as u16) << 8 | (arr[9] as u16), - (arr[10] as u16) << 8 | (arr[11] as u16), - (arr[12] as u16) << 8 | (arr[13] as u16), - (arr[14] as u16) << 8 | (arr[15] as u16), - ); - - Ok(SocketAddr::V6(SocketAddrV6::new(ip, - ::ntoh((*sa).sin6_port), - (*sa).sin6_flowinfo, - (*sa).sin6_scope_id))) - } - } - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid argument")), - } -} - fn dur2timeval(dur: Option) -> io::Result { match dur { Some(dur) => { diff --git a/src/sys/windows.rs b/src/sys/windows.rs index bbefac39..21038e3f 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -14,7 +14,7 @@ use std::io::{Read, Write}; use std::io; use std::mem; use std::net::Shutdown; -use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; +use std::net::{self, Ipv4Addr, Ipv6Addr}; use std::os::windows::prelude::*; use std::ptr; use std::sync::{Once, ONCE_INIT}; @@ -24,6 +24,8 @@ use kernel32; use winapi::*; use ws2_32; +use SockAddr; + const HANDLE_FLAG_INHERIT: DWORD = 0x00000001; const MSG_PEEK: c_int = 0x2; const SD_BOTH: c_int = 2; @@ -77,10 +79,9 @@ impl Socket { } } - pub fn bind(&self, addr: &SocketAddr) -> io::Result<()> { - let (addr, len) = addr2raw(addr); + pub fn bind(&self, addr: &SockAddr) -> io::Result<()> { unsafe { - if ws2_32::bind(self.socket, addr, len) == 0 { + if ws2_32::bind(self.socket, addr.as_ptr(), addr.len()) == 0 { Ok(()) } else { Err(last_error()) @@ -98,10 +99,9 @@ impl Socket { } } - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - let (addr, len) = addr2raw(addr); + pub fn connect(&self, addr: &SockAddr) -> io::Result<()> { unsafe { - if ws2_32::connect(self.socket, addr, len) == 0 { + if ws2_32::connect(self.socket, addr.as_ptr(), addr.len()) == 0 { Ok(()) } else { Err(last_error()) @@ -109,7 +109,7 @@ impl Socket { } } - pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { + pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> { self.set_nonblocking(true)?; let r = self.connect(addr); self.set_nonblocking(true)?; @@ -157,7 +157,7 @@ impl Socket { } } - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { unsafe { let mut storage: SOCKADDR_STORAGE = mem::zeroed(); let mut len = mem::size_of_val(&storage) as c_int; @@ -166,11 +166,11 @@ impl Socket { &mut len) != 0 { return Err(last_error()) } - raw2addr(&storage, len) + Ok(SockAddr::from_raw_parts(&storage as *const _ as *const _, len)) } } - pub fn peer_addr(&self) -> io::Result { + pub fn peer_addr(&self) -> io::Result { unsafe { let mut storage: SOCKADDR_STORAGE = mem::zeroed(); let mut len = mem::size_of_val(&storage) as c_int; @@ -179,7 +179,7 @@ impl Socket { &mut len) != 0 { return Err(last_error()) } - raw2addr(&storage, len) + Ok(SockAddr::from_raw_parts(&storage as *const _ as *const _, len)) } } @@ -207,7 +207,7 @@ impl Socket { } } - pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> { + pub fn accept(&self) -> io::Result<(Socket, SockAddr)> { unsafe { let mut storage: SOCKADDR_STORAGE = mem::zeroed(); let mut len = mem::size_of_val(&storage) as c_int; @@ -221,7 +221,7 @@ impl Socket { socket => Socket::from_raw_socket(socket), }; socket.set_no_inherit()?; - let addr = raw2addr(&storage, len)?; + let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, len); Ok((socket, addr)) } } @@ -297,16 +297,16 @@ impl Socket { } } - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { self.recvfrom(buf, 0) } - pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> { self.recvfrom(buf, MSG_PEEK) } fn recvfrom(&self, buf: &mut [u8], flags: c_int) - -> io::Result<(usize, SocketAddr)> { + -> io::Result<(usize, SockAddr)> { unsafe { let mut storage: SOCKADDR_STORAGE = mem::zeroed(); let mut addrlen = mem::size_of_val(&storage) as c_int; @@ -324,7 +324,8 @@ impl Socket { SOCKET_ERROR => return Err(last_error()), n => n as usize, }; - Ok((n, raw2addr(&storage, addrlen)?)) + let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, addrlen); + Ok((n, addr)) } } @@ -344,16 +345,15 @@ impl Socket { } } - pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result { + pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result { unsafe { - let (addr, len) = addr2raw(addr); let n = { ws2_32::sendto(self.socket, buf.as_ptr() as *const c_char, clamp(buf.len()), 0, - addr, - len) + addr.as_ptr(), + addr.len()) }; if n == SOCKET_ERROR { Err(last_error()) @@ -837,59 +837,6 @@ fn clamp(input: usize) -> c_int { cmp::min(input, ::max_value() as usize) as c_int } -fn addr2raw(addr: &SocketAddr) -> (*const SOCKADDR, c_int) { - match *addr { - SocketAddr::V4(ref a) => { - (a as *const _ as *const _, mem::size_of_val(a) as c_int) - } - SocketAddr::V6(ref a) => { - (a as *const _ as *const _, mem::size_of_val(a) as c_int) - } - } -} - -fn raw2addr(storage: &SOCKADDR_STORAGE, len: c_int) -> io::Result { - match storage.ss_family as c_int { - AF_INET => { - unsafe { - assert!(len as usize >= mem::size_of::()); - let sa = storage as *const _ as *const SOCKADDR_IN; - let bits = ::ntoh((*sa).sin_addr.S_un); - let ip = Ipv4Addr::new((bits >> 24) as u8, - (bits >> 16) as u8, - (bits >> 8) as u8, - bits as u8); - Ok(SocketAddr::V4(SocketAddrV4::new(ip, ::ntoh((*sa).sin_port)))) - } - } - AF_INET6 => { - unsafe { - assert!(len as usize >= mem::size_of::()); - - let sa = storage as *const _ as *const sockaddr_in6; - let arr = (*sa).sin6_addr.s6_addr; - - let ip = Ipv6Addr::new( - (arr[0] as u16) << 8 | (arr[1] as u16), - (arr[2] as u16) << 8 | (arr[3] as u16), - (arr[4] as u16) << 8 | (arr[5] as u16), - (arr[6] as u16) << 8 | (arr[7] as u16), - (arr[8] as u16) << 8 | (arr[9] as u16), - (arr[10] as u16) << 8 | (arr[11] as u16), - (arr[12] as u16) << 8 | (arr[13] as u16), - (arr[14] as u16) << 8 | (arr[15] as u16), - ); - - Ok(SocketAddr::V6(SocketAddrV6::new(ip, - ::ntoh((*sa).sin6_port), - (*sa).sin6_flowinfo, - (*sa).sin6_scope_id))) - } - } - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid argument")), - } -} - fn dur2ms(dur: Option) -> io::Result { match dur { Some(dur) => { From 4891ec633d7a3784d400ce11f2461e35ec038e53 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 12 Jun 2017 19:34:46 -0700 Subject: [PATCH 2/2] Add a static size check --- src/sockaddr.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/sockaddr.rs b/src/sockaddr.rs index 8ac83da8..e27df364 100644 --- a/src/sockaddr.rs +++ b/src/sockaddr.rs @@ -4,9 +4,11 @@ use std::ptr; use std::fmt; #[cfg(unix)] -use libc::{sockaddr, sockaddr_storage, sa_family_t, socklen_t, AF_INET, AF_INET6}; +use libc::{sockaddr, sockaddr_storage, sockaddr_in, sockaddr_in6, sa_family_t, socklen_t, AF_INET, + AF_INET6}; #[cfg(windows)] use winapi::{SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage, + SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6, ADDRESS_FAMILY as sa_family_t, socklen_t, AF_INET, AF_INET6}; use SockAddr; @@ -76,6 +78,14 @@ impl SockAddr { // SocketAddrV4 and SocketAddrV6 are just wrappers around sockaddr_in and sockaddr_in6 +// check to make sure that the sizes at least match up +fn _size_checks(v4: SocketAddrV4, v6: SocketAddrV6) { + unsafe { + mem::transmute::(v4); + mem::transmute::(v6); + } +} + impl From for SockAddr { fn from(addr: SocketAddrV4) -> SockAddr { unsafe { @@ -85,7 +95,6 @@ impl From for SockAddr { } } - impl From for SockAddr { fn from(addr: SocketAddrV6) -> SockAddr { unsafe {