Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use VsockAddr instead of cid and port where appropiate #40

Merged
merged 1 commit into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ impl VsockListener {
}

/// Create a new Virtio socket listener associated with this event loop.
pub fn bind(cid: u32, port: u32) -> Result<Self> {
let l = vsock::VsockListener::bind_with_cid_port(cid, port)?;
pub fn bind(addr: VsockAddr) -> Result<Self> {
let l = vsock::VsockListener::bind_with_cid_port(addr.cid(), addr.port())?;
Self::new(l)
}

Expand Down
6 changes: 2 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ impl VsockStream {
}

/// Open a connection to a remote host.
pub async fn connect(cid: u32, port: u32) -> Result<Self> {
let vsock_addr = VsockAddr::new(cid, port);

pub async fn connect(addr: VsockAddr) -> Result<Self> {
let socket = unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) };
if socket < 0 {
return Err(Error::last_os_error());
Expand All @@ -88,7 +86,7 @@ impl VsockStream {
if unsafe {
connect(
socket,
&vsock_addr as *const _ as *const sockaddr,
&addr as *const _ as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
} < 0
Expand Down
6 changes: 3 additions & 3 deletions test_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use clap::{crate_authors, crate_version, App, Arg};
use futures::StreamExt as _;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_vsock::VsockListener;
use tokio_vsock::{VsockListener, VsockAddr};

/// A simple Virtio socket server that uses Hyper to response to requests.
#[tokio::main]
Expand All @@ -42,8 +42,8 @@ async fn main() -> Result<(), ()> {
.parse::<u32>()
.expect("port must be a valid integer");

let listener = VsockListener::bind(libc::VMADDR_CID_ANY, listen_port)
.expect("unable to bind virtio listener");
let addr = VsockAddr::new(libc::VMADDR_CID_ANY, listen_port);
let listener = VsockListener::bind(addr).expect("unable to bind virtio listener");

println!("Listening for connections on port: {}", listen_port);

Expand Down
26 changes: 11 additions & 15 deletions tests/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use rand::RngCore;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_vsock::{VsockListener, VsockStream};
use tokio_vsock::{VsockAddr, VsockListener, VsockStream};

const TEST_BLOB_SIZE: usize = 100_000;
const TEST_BLOCK_SIZE: usize = 5_000;
Expand All @@ -39,9 +39,8 @@ async fn test_vsock_server() {
rx_blob.resize(TEST_BLOB_SIZE, 0);
rng.fill_bytes(&mut blob);

let mut stream = VsockStream::connect(3, 8000)
.await
.expect("connection failed");
let addr = VsockAddr::new(3, 8000);
let mut stream = VsockStream::connect(addr).await.expect("connection failed");

while tx_pos < TEST_BLOB_SIZE {
let written_bytes = stream
Expand Down Expand Up @@ -75,7 +74,8 @@ async fn test_vsock_server() {

#[tokio::test]
async fn test_vsock_conn_error() {
let err = VsockStream::connect(3, 8001)
let addr = VsockAddr::new(3, 8001);
let err = VsockStream::connect(addr)
.await
.expect_err("connection succeeded")
.raw_os_error()
Expand All @@ -94,8 +94,8 @@ async fn split_vsock() {
const MSG: &[u8] = b"split";
const PORT: u32 = 8002;

let mut listener =
VsockListener::bind(tokio_vsock::VMADDR_CID_LOCAL, PORT).expect("connection failed");
let addr = VsockAddr::new(tokio_vsock::VMADDR_CID_LOCAL, PORT);
let mut listener = VsockListener::bind(addr).expect("connection failed");

let handle = tokio::task::spawn(async move {
let (mut stream, _) = listener
Expand All @@ -115,9 +115,7 @@ async fn split_vsock() {
assert_eq!(&read_buf[..read_len], MSG);
});

let mut stream = VsockStream::connect(tokio_vsock::VMADDR_CID_LOCAL, PORT)
.await
.expect("connection failed");
let mut stream = VsockStream::connect(addr).await.expect("connection failed");
let (mut read_half, mut write_half) = stream.split();

let mut read_buf = [0u8; 32];
Expand Down Expand Up @@ -145,8 +143,8 @@ async fn into_split_vsock() {
const MSG: &[u8] = b"split";
const PORT: u32 = 8001;

let mut listener =
VsockListener::bind(tokio_vsock::VMADDR_CID_LOCAL, PORT).expect("connection failed");
let addr = VsockAddr::new(tokio_vsock::VMADDR_CID_LOCAL, PORT);
let mut listener = VsockListener::bind(addr).expect("connection failed");

let handle = tokio::task::spawn(async move {
let (mut stream, _) = listener
Expand All @@ -166,9 +164,7 @@ async fn into_split_vsock() {
assert_eq!(&read_buf[..read_len], MSG);
});

let stream = VsockStream::connect(tokio_vsock::VMADDR_CID_LOCAL, PORT)
.await
.expect("connection failed");
let stream = VsockStream::connect(addr).await.expect("connection failed");
let (mut read_half, mut write_half) = stream.into_split();

let mut read_buf = [0u8; 32];
Expand Down
Loading