Skip to content

Commit

Permalink
Add socket2 integration for address handling and retrieval (#4716)
Browse files Browse the repository at this point in the history
  • Loading branch information
masa-koz authored Dec 19, 2024
1 parent 62ecc16 commit a03d5c2
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 49 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ bitfield = "0.17.0"
libc = "0.2.0"
c-types = "4.0.0"
serde = { version = "1.0.117", features = ["derive"] }
socket2 = "0.5.8"
146 changes: 97 additions & 49 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ use c_types::AF_INET;
use c_types::AF_INET6;
#[allow(unused_imports)]
use c_types::AF_UNSPEC;
use c_types::{sa_family_t, sockaddr_in, sockaddr_in6, socklen_t};
use libc::c_void;
use serde::{Deserialize, Serialize};
use socket2::SockAddr;
use std::convert::TryInto;
use std::fmt;
use std::io;
use std::mem;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::option::Option;
use std::ptr;
use std::result::Result;
Expand Down Expand Up @@ -40,27 +45,6 @@ pub const ADDRESS_FAMILY_INET: AddressFamily = c_types::AF_INET as u16;
#[allow(clippy::unnecessary_cast)]
pub const ADDRESS_FAMILY_INET6: AddressFamily = c_types::AF_INET6 as u16;

/// IPv4 address payload.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct sockaddr_in {
pub family: AddressFamily,
pub port: u16,
pub addr: u32,
pub zero: [u8; 8usize],
}

/// IPv6 address payload.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct sockaddr_in6 {
pub family: AddressFamily,
pub port: u16,
pub flow_info: u32,
pub addr: [u8; 16usize],
pub scope_id: u32,
}

/// Generic representation of IPv4 or IPv6 addresses.
#[repr(C)]
#[derive(Copy, Clone)]
Expand All @@ -70,40 +54,66 @@ pub union Addr {
}

impl Addr {
/// Create a representation of IPv4 address and perform Network byte order conversion
/// on the port number.
pub fn ipv4(family: u16, port: u16, addr: u32) -> Addr {
Addr {
ipv4: sockaddr_in {
family,
port,
addr,
zero: [0, 0, 0, 0, 0, 0, 0, 0],
},
/// Converts the `Addr` to a `SocketAddr`.
pub fn as_socket(&self) -> Option<SocketAddr> {
unsafe {
SockAddr::try_init(|addr, len| {
if self.ipv4.sin_family == AF_INET as sa_family_t {
let addr = addr.cast::<sockaddr_in>();
*addr = self.ipv4;
*len = mem::size_of::<sockaddr_in>() as socklen_t;
Ok(())
} else if self.ipv4.sin_family == AF_INET6 as sa_family_t {
let addr = addr.cast::<sockaddr_in6>();
*addr = self.ipv6;
*len = mem::size_of::<sockaddr_in6>() as socklen_t;
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::Other))
}
})
}
.map(|((), addr)| addr.as_socket().unwrap())
.ok()
}

/// Create a representation of IPv6 address and perform Network byte order conversion
/// on the port number.
pub fn ipv6(
family: u16,
port: u16,
flow_info: u32,
addr: [u8; 16usize],
scope_id: u32,
) -> Addr {
Addr {
ipv6: sockaddr_in6 {
family,
port,
flow_info,
addr,
scope_id,
},
/// Get port number from the `Addr`.
pub fn port(&self) -> u16 {
unsafe { u16::from_be(self.ipv4.sin_port) }
}
}

impl From<SocketAddr> for Addr {
fn from(addr: SocketAddr) -> Addr {
match addr {
SocketAddr::V4(addr) => addr.into(),
SocketAddr::V6(addr) => addr.into(),
}
}
}

impl From<SocketAddrV4> for Addr {
fn from(addr: SocketAddrV4) -> Addr {
// SAFETY: a `Addr` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<Addr>() };
let addr: SockAddr = addr.into();
let addr = addr.as_ptr().cast::<sockaddr_in>();
storage.ipv4 = unsafe { *addr };
storage
}
}

impl From<SocketAddrV6> for Addr {
fn from(addr: SocketAddrV6) -> Addr {
// SAFETY: a `Addr` of all zeros is valid.
let mut storage = unsafe { mem::zeroed::<Addr>() };
let addr: SockAddr = addr.into();
let addr = addr.as_ptr().cast::<sockaddr_in6>();
storage.ipv6 = unsafe { *addr };
storage
}
}

#[cfg(target_os = "windows")]
mod status {
pub const QUIC_STATUS_SUCCESS: u32 = 0x0;
Expand Down Expand Up @@ -1666,6 +1676,40 @@ impl Connection {
}
Ok(())
}

pub fn get_local_addr(&self) -> Result<Addr, u32> {
let mut addr_buffer: [u8; mem::size_of::<Addr>()] = [0; mem::size_of::<Addr>()];
let addr_size_mut = mem::size_of::<Addr>();
let status = unsafe {
((*self.table).get_param)(
self.handle,
PARAM_CONN_LOCAL_ADDRESS,
(&addr_size_mut) as *const usize as *const u32 as *mut u32,
addr_buffer.as_mut_ptr() as *const c_void,
)
};
if Status::failed(status) {
return Err(status);
}
Ok(unsafe { *(addr_buffer.as_ptr() as *const c_void as *const Addr) })
}

pub fn get_remote_addr(&self) -> Result<Addr, u32> {
let mut addr_buffer: [u8; mem::size_of::<Addr>()] = [0; mem::size_of::<Addr>()];
let addr_size_mut = mem::size_of::<Addr>();
let status = unsafe {
((*self.table).get_param)(
self.handle,
PARAM_CONN_REMOTE_ADDRESS,
(&addr_size_mut) as *const usize as *const u32 as *mut u32,
addr_buffer.as_mut_ptr() as *const c_void,
)
};
if Status::failed(status) {
return Err(status);
}
Ok(unsafe { *(addr_buffer.as_ptr() as *const c_void as *const Addr) })
}
}

impl Drop for Connection {
Expand Down Expand Up @@ -1820,7 +1864,11 @@ extern "C" fn test_conn_callback(
) -> u32 {
let connection = unsafe { &*(context as *const Connection) };
match event.event_type {
CONNECTION_EVENT_CONNECTED => println!("Connected"),
CONNECTION_EVENT_CONNECTED => {
let local_addr = connection.get_local_addr().unwrap().as_socket().unwrap();
let remote_addr = connection.get_remote_addr().unwrap().as_socket().unwrap();
println!("Connected({}, {})", local_addr, remote_addr);
}
CONNECTION_EVENT_SHUTDOWN_INITIATED_BY_TRANSPORT => {
println!("Transport shutdown 0x{:x}", unsafe {
event.payload.shutdown_initiated_by_transport.status
Expand Down

0 comments on commit a03d5c2

Please sign in to comment.