Skip to content

Commit

Permalink
Merge pull request #40 from rust-vsock/use-vsockaddr
Browse files Browse the repository at this point in the history
fix: Use VsockAddr instead of cid and port where appropiate
  • Loading branch information
jalil-salame authored Dec 7, 2023
2 parents d7a6ef0 + d411193 commit f0aeda6
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ impl VsockListener {
}

/// Create a new Virtio socket listener associated with this event loop.
pub fn bind(cid: u32, port: u32) -> Result<Self> {
let l = vsock::VsockListener::bind_with_cid_port(cid, port)?;
pub fn bind(addr: VsockAddr) -> Result<Self> {
let l = vsock::VsockListener::bind_with_cid_port(addr.cid(), addr.port())?;
Self::new(l)
}

Expand Down
6 changes: 2 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ impl VsockStream {
}

/// Open a connection to a remote host.
pub async fn connect(cid: u32, port: u32) -> Result<Self> {
let vsock_addr = VsockAddr::new(cid, port);

pub async fn connect(addr: VsockAddr) -> Result<Self> {
let socket = unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) };
if socket < 0 {
return Err(Error::last_os_error());
Expand All @@ -88,7 +86,7 @@ impl VsockStream {
if unsafe {
connect(
socket,
&vsock_addr as *const _ as *const sockaddr,
&addr as *const _ as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
} < 0
Expand Down
6 changes: 3 additions & 3 deletions test_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use clap::{crate_authors, crate_version, App, Arg};
use futures::StreamExt as _;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_vsock::VsockListener;
use tokio_vsock::{VsockListener, VsockAddr};

/// A simple Virtio socket server that uses Hyper to response to requests.
#[tokio::main]
Expand All @@ -42,8 +42,8 @@ async fn main() -> Result<(), ()> {
.parse::<u32>()
.expect("port must be a valid integer");

let listener = VsockListener::bind(libc::VMADDR_CID_ANY, listen_port)
.expect("unable to bind virtio listener");
let addr = VsockAddr::new(libc::VMADDR_CID_ANY, listen_port);
let listener = VsockListener::bind(addr).expect("unable to bind virtio listener");

println!("Listening for connections on port: {}", listen_port);

Expand Down
26 changes: 11 additions & 15 deletions tests/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use rand::RngCore;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_vsock::{VsockListener, VsockStream};
use tokio_vsock::{VsockAddr, VsockListener, VsockStream};

const TEST_BLOB_SIZE: usize = 100_000;
const TEST_BLOCK_SIZE: usize = 5_000;
Expand All @@ -39,9 +39,8 @@ async fn test_vsock_server() {
rx_blob.resize(TEST_BLOB_SIZE, 0);
rng.fill_bytes(&mut blob);

let mut stream = VsockStream::connect(3, 8000)
.await
.expect("connection failed");
let addr = VsockAddr::new(3, 8000);
let mut stream = VsockStream::connect(addr).await.expect("connection failed");

while tx_pos < TEST_BLOB_SIZE {
let written_bytes = stream
Expand Down Expand Up @@ -75,7 +74,8 @@ async fn test_vsock_server() {

#[tokio::test]
async fn test_vsock_conn_error() {
let err = VsockStream::connect(3, 8001)
let addr = VsockAddr::new(3, 8001);
let err = VsockStream::connect(addr)
.await
.expect_err("connection succeeded")
.raw_os_error()
Expand All @@ -94,8 +94,8 @@ async fn split_vsock() {
const MSG: &[u8] = b"split";
const PORT: u32 = 8002;

let mut listener =
VsockListener::bind(tokio_vsock::VMADDR_CID_LOCAL, PORT).expect("connection failed");
let addr = VsockAddr::new(tokio_vsock::VMADDR_CID_LOCAL, PORT);
let mut listener = VsockListener::bind(addr).expect("connection failed");

let handle = tokio::task::spawn(async move {
let (mut stream, _) = listener
Expand All @@ -115,9 +115,7 @@ async fn split_vsock() {
assert_eq!(&read_buf[..read_len], MSG);
});

let mut stream = VsockStream::connect(tokio_vsock::VMADDR_CID_LOCAL, PORT)
.await
.expect("connection failed");
let mut stream = VsockStream::connect(addr).await.expect("connection failed");
let (mut read_half, mut write_half) = stream.split();

let mut read_buf = [0u8; 32];
Expand Down Expand Up @@ -145,8 +143,8 @@ async fn into_split_vsock() {
const MSG: &[u8] = b"split";
const PORT: u32 = 8001;

let mut listener =
VsockListener::bind(tokio_vsock::VMADDR_CID_LOCAL, PORT).expect("connection failed");
let addr = VsockAddr::new(tokio_vsock::VMADDR_CID_LOCAL, PORT);
let mut listener = VsockListener::bind(addr).expect("connection failed");

let handle = tokio::task::spawn(async move {
let (mut stream, _) = listener
Expand All @@ -166,9 +164,7 @@ async fn into_split_vsock() {
assert_eq!(&read_buf[..read_len], MSG);
});

let stream = VsockStream::connect(tokio_vsock::VMADDR_CID_LOCAL, PORT)
.await
.expect("connection failed");
let stream = VsockStream::connect(addr).await.expect("connection failed");
let (mut read_half, mut write_half) = stream.into_split();

let mut read_buf = [0u8; 32];
Expand Down

0 comments on commit f0aeda6

Please sign in to comment.