Skip to content

Commit

Permalink
Merge pull request #2 from sfackler/sockaddr
Browse files Browse the repository at this point in the history
Add a SockAddr type
  • Loading branch information
alexcrichton authored Jun 13, 2017
2 parents f11c451 + 4891ec6 commit 9abb700
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 177 deletions.
27 changes: 20 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
//! # Examples
//!
//! ```no_run
//! use std::net::SocketAddr;
//! use socket2::{Socket, Domain, Type};
//!
//! // create a TCP listener bound to two addresses
//! let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
//!
//! socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap();
//! socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap();
//! socket.bind(&"127.0.0.1:12345".parse::<SocketAddr>().unwrap().into()).unwrap();
//! socket.bind(&"127.0.0.1:12346".parse::<SocketAddr>().unwrap().into()).unwrap();
//! socket.listen(128).unwrap();
//!
//! let listener = socket.into_tcp_listener();
Expand All @@ -45,6 +46,10 @@

use utils::NetInt;

#[cfg(unix)] use libc::{sockaddr_storage, socklen_t};
#[cfg(windows)] use winapi::{SOCKADDR_STORAGE as sockaddr_storage, socklen_t};

mod sockaddr;
mod socket;
mod utils;

Expand All @@ -63,13 +68,14 @@ mod utils;
/// # Examples
///
/// ```no_run
/// use socket2::{Socket, Domain, Type};
/// use std::net::SocketAddr;
/// use socket2::{Socket, Domain, Type, SockAddr};
///
/// // create a TCP listener bound to two addresses
/// let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
///
/// socket.bind(&"127.0.0.1:12345".parse().unwrap()).unwrap();
/// socket.bind(&"127.0.0.1:12346".parse().unwrap()).unwrap();
/// socket.bind(&"127.0.0.1:12345".parse::<SocketAddr>().unwrap().into()).unwrap();
/// socket.bind(&"127.0.0.1:12346".parse::<SocketAddr>().unwrap().into()).unwrap();
/// socket.listen(128).unwrap();
///
/// let listener = socket.into_tcp_listener();
Expand All @@ -79,6 +85,15 @@ pub struct Socket {
inner: sys::Socket,
}

/// The address of a socket.
///
/// `SockAddr`s may be constructed directly to and from the standard library
/// `SocketAddr`, `SocketAddrV4`, and `SocketAddrV6` types.
pub struct SockAddr {
storage: sockaddr_storage,
len: socklen_t,
}

/// Specification of the communication domain for a socket.
///
/// This is a newtype wrapper around an integer which provides a nicer API in
Expand Down Expand Up @@ -111,5 +126,3 @@ pub struct Type(i32);
pub struct Protocol(i32);

fn hton<I: NetInt>(i: I) -> I { i.to_be() }

fn ntoh<I: NetInt>(i: I) -> I { I::from_be(i) }
139 changes: 139 additions & 0 deletions src/sockaddr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::net::{SocketAddrV4, SocketAddrV6, SocketAddr};
use std::mem;
use std::ptr;
use std::fmt;

#[cfg(unix)]
use libc::{sockaddr, sockaddr_storage, sockaddr_in, sockaddr_in6, sa_family_t, socklen_t, AF_INET,
AF_INET6};
#[cfg(windows)]
use winapi::{SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage,
SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
ADDRESS_FAMILY as sa_family_t, socklen_t, AF_INET, AF_INET6};

use SockAddr;

impl fmt::Debug for SockAddr {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
let mut builder = fmt.debug_struct("SockAddr");
builder.field("family", &self.family());
if let Some(addr) = self.as_inet() {
builder.field("inet", &addr);
} else if let Some(addr) = self.as_inet6() {
builder.field("inet6", &addr);
}
builder.finish()
}
}

impl SockAddr {
/// Constructs a `SockAddr` from its raw components.
pub unsafe fn from_raw_parts(addr: *const sockaddr, len: socklen_t) -> SockAddr {
let mut storage = mem::uninitialized::<sockaddr_storage>();
ptr::copy_nonoverlapping(addr as *const _ as *const u8,
&mut storage as *mut _ as *mut u8,
len as usize);

SockAddr {
storage: storage,
len: len,
}
}

unsafe fn as_<T>(&self, family: sa_family_t) -> Option<T> {
if self.storage.ss_family != family {
return None;
}

Some(mem::transmute_copy(&self.storage))
}

/// Returns this address as a `SocketAddrV4` if it is in the `AF_INET`
/// family.
pub fn as_inet(&self) -> Option<SocketAddrV4> {
unsafe { self.as_(AF_INET as sa_family_t) }
}

/// Returns this address as a `SocketAddrV4` if it is in the `AF_INET6`
/// family.
pub fn as_inet6(&self) -> Option<SocketAddrV6> {
unsafe { self.as_(AF_INET6 as sa_family_t) }
}

/// Returns this address's family.
pub fn family(&self) -> sa_family_t {
self.storage.ss_family
}

/// Returns the size of this address in bytes.
pub fn len(&self) -> socklen_t {
self.len
}

/// Returns a raw pointer to the address.
pub fn as_ptr(&self) -> *const sockaddr {
&self.storage as *const _ as *const _
}
}

// SocketAddrV4 and SocketAddrV6 are just wrappers around sockaddr_in and sockaddr_in6

// check to make sure that the sizes at least match up
fn _size_checks(v4: SocketAddrV4, v6: SocketAddrV6) {
unsafe {
mem::transmute::<SocketAddrV4, sockaddr_in>(v4);
mem::transmute::<SocketAddrV6, sockaddr_in6>(v6);
}
}

impl From<SocketAddrV4> for SockAddr {
fn from(addr: SocketAddrV4) -> SockAddr {
unsafe {
SockAddr::from_raw_parts(&addr as *const _ as *const _,
mem::size_of::<SocketAddrV4>() as socklen_t)
}
}
}

impl From<SocketAddrV6> for SockAddr {
fn from(addr: SocketAddrV6) -> SockAddr {
unsafe {
SockAddr::from_raw_parts(&addr as *const _ as *const _,
mem::size_of::<SocketAddrV6>() as socklen_t)
}
}
}

impl From<SocketAddr> for SockAddr {
fn from(addr: SocketAddr) -> SockAddr {
match addr {
SocketAddr::V4(addr) => addr.into(),
SocketAddr::V6(addr) => addr.into(),
}
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn inet() {
let raw = "127.0.0.1:80".parse::<SocketAddrV4>().unwrap();
let addr = SockAddr::from(raw);
assert!(addr.as_inet6().is_none());
let addr = addr.as_inet().unwrap();
assert_eq!(raw, addr);
}

#[test]
fn inet6() {
let raw = "[2001:db8::ff00:42:8329]:80"
.parse::<SocketAddrV6>()
.unwrap();
let addr = SockAddr::from(raw);
assert!(addr.as_inet().is_none());
let addr = addr.as_inet6().unwrap();
assert_eq!(raw, addr);
}
}
28 changes: 15 additions & 13 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

use std::fmt;
use std::io::{self, Read, Write};
use std::net::{self, SocketAddr, Ipv4Addr, Ipv6Addr, Shutdown};
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
use std::time::Duration;

#[cfg(unix)]
Expand All @@ -19,7 +19,7 @@ use libc as c;
use winapi as c;

use sys;
use {Socket, Protocol, Domain, Type};
use {Socket, Protocol, Domain, Type, SockAddr};

impl Socket {
/// Creates a new socket ready to be configured.
Expand Down Expand Up @@ -58,7 +58,7 @@ impl Socket {
///
/// An error will be returned if `listen` or `connect` has already been
/// called on this builder.
pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
self.inner.connect(addr)
}

Expand All @@ -81,15 +81,15 @@ impl Socket {
///
/// If the connection request times out, it may still be processing in the
/// background - a second call to `connect` or `connect_timeout` may fail.
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> {
self.inner.connect_timeout(addr, timeout)
}

/// Binds this socket to the specified address.
///
/// This function directly corresponds to the bind(2) function on Windows
/// and Unix.
pub fn bind(&self, addr: &SocketAddr) -> io::Result<()> {
pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
self.inner.bind(addr)
}

Expand All @@ -110,19 +110,19 @@ impl Socket {
/// This function will block the calling thread until a new connection is
/// established. When established, the corresponding `Socket` and the
/// remote peer's address will be returned.
pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
pub fn accept(&self) -> io::Result<(Socket, SockAddr)> {
self.inner.accept().map(|(socket, addr)| {
(Socket { inner: socket }, addr)
})
}

/// Returns the socket address of the local half of this TCP connection.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> io::Result<SockAddr> {
self.inner.local_addr()
}

/// Returns the socket address of the remote peer of this TCP connection.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
pub fn peer_addr(&self) -> io::Result<SockAddr> {
self.inner.peer_addr()
}

Expand Down Expand Up @@ -184,7 +184,7 @@ impl Socket {

/// Receives data from the socket. On success, returns the number of bytes
/// read and the address from whence the data came.
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
self.inner.recv_from(buf)
}

Expand All @@ -195,7 +195,7 @@ impl Socket {
///
/// On success, returns the number of bytes peeked and the address from
/// whence the data came.
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
self.inner.peek_from(buf)
}

Expand All @@ -214,7 +214,7 @@ impl Socket {
///
/// This is typically used on UDP or datagram-oriented sockets. On success
/// returns the number of bytes that were sent.
pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result<usize> {
pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
self.inner.send_to(buf, addr)
}

Expand Down Expand Up @@ -693,12 +693,14 @@ impl From<Protocol> for i32 {

#[cfg(test)]
mod test {
use std::net::SocketAddr;

use super::*;

#[test]
fn connect_timeout_unrouteable() {
// this IP is unroutable, so connections should always time out
let addr: SocketAddr = "10.255.255.1:80".parse().unwrap();
let addr = "10.255.255.1:80".parse::<SocketAddr>().unwrap().into();

let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
match socket.connect_timeout(&addr, Duration::from_millis(250)) {
Expand All @@ -711,7 +713,7 @@ mod test {
#[test]
fn connect_timeout_valid() {
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
socket.bind(&"127.0.0.1:0".parse::<SocketAddr>().unwrap().into()).unwrap();
socket.listen(128).unwrap();

let addr = socket.local_addr().unwrap();
Expand Down
Loading

0 comments on commit 9abb700

Please sign in to comment.