Skip to content

Commit

Permalink
Merge pull request #116 from liubin/feature/normal-unix-domain-socket
Browse files Browse the repository at this point in the history
add support for normal Unix domain socket
  • Loading branch information
lifupan authored Dec 21, 2021
2 parents 72beb93 + c0d98aa commit dfae1ad
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 34 deletions.
8 changes: 4 additions & 4 deletions src/asynchronous/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ where
fn header_to_buf(mh: MessageHeader) -> Vec<u8> {
let mut buf = vec![0u8; MESSAGE_HEADER_LENGTH];

let mut covbuf: &mut [u8] = &mut buf[..4];
BigEndian::write_u32(&mut covbuf, mh.length);
let mut covbuf: &mut [u8] = &mut buf[4..8];
BigEndian::write_u32(&mut covbuf, mh.stream_id);
let covbuf: &mut [u8] = &mut buf[..4];
BigEndian::write_u32(covbuf, mh.length);
let covbuf: &mut [u8] = &mut buf[4..8];
BigEndian::write_u32(covbuf, mh.stream_id);
buf[8] = mh.type_;
buf[9] = mh.flags;

Expand Down
140 changes: 118 additions & 22 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use nix::fcntl::{fcntl, FcntlArg, FdFlag, OFlag};
use nix::sys::socket::*;
use std::os::unix::io::RawFd;

#[derive(Debug)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Domain {
Unix,
#[cfg(target_os = "linux")]
AbstractUnix,
#[cfg(target_os = "linux")]
Vsock,
}

Expand All @@ -39,20 +41,38 @@ pub fn do_listen(listener: RawFd) -> Result<()> {
listen(listener, 10).map_err(|e| Error::Socket(e.to_string()))
}

pub fn parse_host(host: &str) -> Result<(Domain, Vec<&str>)> {
pub fn parse_host(host: &str) -> Result<(Domain, &str)> {
let hostv: Vec<&str> = host.trim().split("://").collect();
if hostv.len() != 2 {
return Err(Error::Others(format!("Host {} is not right", host)));
}

let addr = hostv[1];
if addr.is_empty() {
return Err(Error::Others(format!("address {} is empty", addr)));
}

let domain = match &hostv[0].to_lowercase()[..] {
"unix" => Domain::Unix,
"unix" if !addr.starts_with('@') => Domain::Unix,
#[cfg(not(target_os = "linux"))]
"unix" if addr.starts_with('@') => {
return Err(Error::Others(
"Abstract socket is not supported".to_string(),
))
}
#[cfg(target_os = "linux")]
"unix" if addr.starts_with('@') => Domain::AbstractUnix,
#[cfg(target_os = "linux")]
"vsock" => Domain::Vsock,
x => return Err(Error::Others(format!("Scheme {:?} is not supported", x))),
};

Ok((domain, hostv))
#[cfg(target_os = "linux")]
if domain == Domain::AbstractUnix {
return Ok((domain, &addr[1..]));
}

Ok((domain, addr))
}

pub fn set_fd_close_exec(fd: RawFd) -> Result<RawFd> {
Expand All @@ -72,37 +92,47 @@ pub(crate) const SOCK_CLOEXEC: SockFlag = SockFlag::SOCK_CLOEXEC;
pub(crate) const SOCK_CLOEXEC: SockFlag = SockFlag::empty();

#[cfg(target_os = "linux")]
fn make_addr(host: &str) -> Result<UnixAddr> {
UnixAddr::new_abstract(host.as_bytes()).map_err(err_to_others_err!(e, ""))
fn make_addr(domain: Domain, host: &str) -> Result<UnixAddr> {
match domain {
Domain::Unix => UnixAddr::new(host).map_err(err_to_others_err!(e, "")),
Domain::AbstractUnix => {
UnixAddr::new_abstract(host.as_bytes()).map_err(err_to_others_err!(e, ""))
}
Domain::Vsock => Err(Error::Others(
"function make_addr does not support create vsock socket".to_string(),
)),
}
}

#[cfg(not(target_os = "linux"))]
fn make_addr(host: &str) -> Result<UnixAddr> {
fn make_addr(_domain: Domain, host: &str) -> Result<UnixAddr> {
UnixAddr::new(host).map_err(err_to_others_err!(e, ""))
}

fn make_socket(addr: (&str, u32)) -> Result<(RawFd, Domain, SockAddr)> {
let (host, _) = addr;
let (domain, hostv) = parse_host(host)?;

let sockaddr: SockAddr;
let fd: RawFd;
let get_sock_addr = |domain, host| -> Result<(RawFd, SockAddr)> {
let fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)
.map_err(|e| Error::Socket(e.to_string()))?;

match domain {
Domain::Unix => {
fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)
.map_err(|e| Error::Socket(e.to_string()))?;
// MacOS doesn't support atomic creation of a socket descriptor with SOCK_CLOEXEC flag,
// so there is a chance of leak if fork + exec happens in between of these calls.
#[cfg(target_os = "macos")]
set_fd_close_exec(fd)?;

// MacOS doesn't support atomic creation of a socket descriptor with SOCK_CLOEXEC flag,
// so there is a chance of leak if fork + exec happens in between of these calls.
#[cfg(target_os = "macos")]
set_fd_close_exec(fd)?;
let sockaddr = SockAddr::Unix(make_addr(domain, host)?);
Ok((fd, sockaddr))
};

sockaddr = SockAddr::Unix(make_addr(hostv[1])?);
}
let (fd, sockaddr) = match domain {
Domain::Unix => get_sock_addr(domain, hostv)?,
#[cfg(target_os = "linux")]
Domain::AbstractUnix => get_sock_addr(domain, hostv)?,
#[cfg(target_os = "linux")]
Domain::Vsock => {
let host_port_v: Vec<&str> = hostv[1].split(':').collect();
let host_port_v: Vec<&str> = hostv.split(':').collect();
if host_port_v.len() != 2 {
return Err(Error::Others(format!(
"Host {} is not right for vsock",
Expand All @@ -112,15 +142,16 @@ fn make_socket(addr: (&str, u32)) -> Result<(RawFd, Domain, SockAddr)> {
let port: u32 = host_port_v[1]
.parse()
.expect("the vsock port is not an number");
fd = socket(
let fd = socket(
AddressFamily::Vsock,
SockType::Stream,
SockFlag::SOCK_CLOEXEC,
None,
)
.map_err(|e| Error::Socket(e.to_string()))?;
let cid = addr.1;
sockaddr = SockAddr::new_vsock(cid, port);
let sockaddr = SockAddr::new_vsock(cid, port);
(fd, sockaddr)
}
};

Expand Down Expand Up @@ -180,3 +211,68 @@ pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;

pub const MESSAGE_TYPE_REQUEST: u8 = 0x1;
pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2;

#[cfg(test)]
mod tests {
use super::parse_host;
use super::Domain;

#[cfg(target_os = "linux")]
#[test]
fn test_parse_host() {
for i in &[
(
"unix:///run/a.sock",
Some(Domain::Unix),
"/run/a.sock",
true,
),
("vsock://8:1024", Some(Domain::Vsock), "8:1024", true),
("Vsock://8:1025", Some(Domain::Vsock), "8:1025", true),
(
"unix://@/run/b.sock",
Some(Domain::AbstractUnix),
"/run/b.sock",
true,
),
("abc:///run/c.sock", None, "", false),
] {
let (input, domain, addr, success) = (i.0, i.1, i.2, i.3);
let r = parse_host(input);
if success {
let (rd, ra) = r.unwrap();
assert_eq!(rd, domain.unwrap());
assert_eq!(ra, addr);
} else {
assert!(r.is_err());
}
}
}

#[cfg(not(target_os = "linux"))]
#[test]
fn test_parse_host() {
for i in &[
(
"unix:///run/a.sock",
Some(Domain::Unix),
"/run/a.sock",
true,
),
("vsock:///run/c.sock", None, "", false),
("Vsock:///run/c.sock", None, "", false),
("unix://@/run/b.sock", None, "", false),
("abc:///run/c.sock", None, "", false),
] {
let (input, domain, addr, success) = (i.0, i.1, i.2, i.3);
let r = parse_host(input);
if success {
let (rd, ra) = r.unwrap();
assert_eq!(rd, domain.unwrap());
assert_eq!(ra, addr);
} else {
assert!(r.is_err());
}
}
}
}
8 changes: 4 additions & 4 deletions src/sync/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ pub fn read_message(fd: RawFd) -> Result<(MessageHeader, Vec<u8>)> {
fn write_message_header(fd: RawFd, mh: MessageHeader) -> Result<()> {
let mut buf = [0u8; MESSAGE_HEADER_LENGTH];

let mut covbuf: &mut [u8] = &mut buf[..4];
BigEndian::write_u32(&mut covbuf, mh.length);
let mut covbuf: &mut [u8] = &mut buf[4..8];
BigEndian::write_u32(&mut covbuf, mh.stream_id);
let covbuf: &mut [u8] = &mut buf[..4];
BigEndian::write_u32(covbuf, mh.length);
let covbuf: &mut [u8] = &mut buf[4..8];
BigEndian::write_u32(covbuf, mh.stream_id);
buf[8] = mh.type_;
buf[9] = mh.flags;

Expand Down
8 changes: 4 additions & 4 deletions src/sync/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ type Receiver = mpsc::Receiver<(Vec<u8>, mpsc::SyncSender<Result<Vec<u8>>>)>;
/// A ttrpc Client (sync).
#[derive(Clone)]
pub struct Client {
fd: RawFd,
_fd: RawFd,
sender_tx: Sender,
client_close: Arc<ClientClose>,
_client_close: Arc<ClientClose>,
}

impl Client {
Expand Down Expand Up @@ -208,9 +208,9 @@ impl Client {
});

Client {
fd,
_fd: fd,
sender_tx,
client_close,
_client_close: client_close,
}
}
pub fn request(&self, req: Request) -> Result<Response> {
Expand Down

0 comments on commit dfae1ad

Please sign in to comment.