Skip to content

Commit

Permalink
Move connect into SystemTcpSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
badeend committed Jan 11, 2024
1 parent 775021e commit b8ae1cf
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 113 deletions.
161 changes: 49 additions & 112 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ use crate::preview2::network::SocketAddrUse;
use crate::preview2::pipe::{AsyncReadStream, AsyncWriteStream};
use crate::preview2::tcp::SystemTcpSocket;
use crate::preview2::{
InputStream, OutputStream, Pollable, SocketAddrFamily, SocketResult, Subscribe, WasiView,
DynFuture, InputStream, OutputStream, Pollable, SocketAddrFamily, SocketResult, Subscribe,
WasiView,
};
use rustix::io::Errno;
use rustix::net::sockopt;
use std::io;
use std::net::SocketAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::Interest;
use wasmtime::component::Resource;

impl<T: WasiView> crate::preview2::bindings::sockets::tcp::Host for T {}
Expand Down Expand Up @@ -44,13 +43,7 @@ enum TcpState {
},

/// An outgoing connection is started via `start_connect`.
Connecting,

/// An outgoing connection is ready to be established.
ConnectReady,

/// An outgoing connection was attempted but failed.
ConnectFailed,
Connecting { future: DynFuture<io::Result<()>> },

/// An outgoing connection has been established.
Connected,
Expand Down Expand Up @@ -180,57 +173,31 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {
network: Resource<Network>,
remote_address: IpSocketAddress,
) -> SocketResult<()> {
self.ctx().allowed_network_uses.check_allowed_tcp()?;
let table = self.table_mut();
let r = {
let socket = table.get(&this)?;
let network = table.get(&network)?;
let remote_address: SocketAddr = remote_address.into();

match socket.tcp_state {
TcpState::Default => {}
TcpState::Bound
| TcpState::Connected
| TcpState::ConnectFailed
| TcpState::Listening { .. } => return Err(ErrorCode::InvalidState.into()),
TcpState::Connecting
| TcpState::ConnectReady
| TcpState::ListenStarted
| TcpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
}

util::validate_unicast(&remote_address)?;
util::validate_remote_address(&remote_address)?;
util::validate_address_family(&remote_address, &socket.inner.family)?;

// Ensure that we're allowed to connect to this address.
network.check_socket_addr(&remote_address, SocketAddrUse::TcpConnect)?;

// Do an OS `connect`. Our socket is non-blocking, so it'll either...
util::tcp_connect(socket.tcp_socket(), &remote_address)
};
// At the moment, there's only one network handle (`instance-network`)
// in existence. All we have to do here is validate that the caller indeed
// has possesion of a valid handle and then we're good to go:
let _network = table.get(&network)?;
let socket = table.get_mut(&this)?;
let remote_address: SocketAddr = remote_address.into();

match r {
// succeed immediately,
Ok(()) => {
let socket = table.get_mut(&this)?;
socket.tcp_state = TcpState::ConnectReady;
return Ok(());
}
// continue in progress,
Err(err) if err == Errno::INPROGRESS => {}
// or fail immediately.
Err(err) => {
return Err(match err {
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument.into(), // See `bind` implementation.
_ => err.into(),
});
}
match socket.tcp_state {
TcpState::Default => {}
TcpState::Connecting { .. } => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => return Err(ErrorCode::InvalidState.into()),
}

let socket = table.get_mut(&this)?;
socket.tcp_state = TcpState::Connecting;
let mut future = socket.inner.connect(&remote_address);

// Attempt to return (validation) errors immediately:
let future = match future.try_resolve() {
Some(Err(e)) => return Err(e.into()),
Some(Ok(())) => DynFuture::ready(Ok(())),
None => future,
};

socket.tcp_state = TcpState::Connecting { future };
Ok(())
}

Expand All @@ -241,41 +208,26 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {
let table = self.table_mut();
let socket = table.get_mut(&this)?;

match socket.tcp_state {
TcpState::ConnectReady => {}
TcpState::Connecting => {
// Do a `poll` to test for completion, using a timeout of zero
// to avoid blocking.
match rustix::event::poll(
&mut [rustix::event::PollFd::new(
socket.tcp_socket(),
rustix::event::PollFlags::OUT,
)],
0,
) {
Ok(0) => return Err(ErrorCode::WouldBlock.into()),
Ok(_) => (),
Err(err) => Err(err).unwrap(),
}

// Check whether the connect succeeded.
match sockopt::get_socket_error(socket.tcp_socket()) {
Ok(Ok(())) => {}
Err(err) | Ok(Err(err)) => {
socket.tcp_state = TcpState::ConnectFailed;
return Err(err.into());
}
}
}
_ => return Err(ErrorCode::NotInProgress.into()),
let TcpState::Connecting { future } = &mut socket.tcp_state else {
return Err(ErrorCode::NotInProgress.into());
};

socket.tcp_state = TcpState::Connected;
let (input, output) = socket.as_split();
let input_stream = self.table_mut().push_child(input, &this)?;
let output_stream = self.table_mut().push_child(output, &this)?;
match future.try_resolve() {
Some(Ok(())) => {
socket.tcp_state = TcpState::Connected;

Ok((input_stream, output_stream))
let (input, output) = socket.as_split();
let input_stream = self.table_mut().push_child(input, &this)?;
let output_stream = self.table_mut().push_child(output, &this)?;

Ok((input_stream, output_stream))
}
Some(Err(e)) => {
socket.tcp_state = TcpState::Default;
Err(e.into())
}
None => Err(ErrorCode::WouldBlock.into()),
}
}

fn start_listen(&mut self, this: Resource<TcpSocketWrapper>) -> SocketResult<()> {
Expand All @@ -285,14 +237,12 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {

match socket.tcp_state {
TcpState::Bound => {}
TcpState::Default
| TcpState::Connected
| TcpState::ConnectFailed
| TcpState::Listening { .. } => return Err(ErrorCode::InvalidState.into()),
TcpState::ListenStarted
| TcpState::Connecting
| TcpState::ConnectReady
| TcpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
TcpState::Default | TcpState::Connected | TcpState::Listening { .. } => {
return Err(ErrorCode::InvalidState.into())
}
TcpState::ListenStarted | TcpState::Connecting { .. } | TcpState::BindStarted => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
}

socket.inner.listen()?;
Expand Down Expand Up @@ -383,9 +333,7 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {

match socket.tcp_state {
TcpState::Connected => {}
TcpState::Connecting | TcpState::ConnectReady => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
TcpState::Connecting { .. } => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => return Err(ErrorCode::InvalidState.into()),
}

Expand Down Expand Up @@ -573,9 +521,7 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {

match socket.tcp_state {
TcpState::Connected => {}
TcpState::Connecting | TcpState::ConnectReady => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
TcpState::Connecting { .. } => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => return Err(ErrorCode::InvalidState.into()),
}

Expand Down Expand Up @@ -607,23 +553,14 @@ impl Subscribe for TcpSocketWrapper {
| TcpState::BindStarted
| TcpState::Bound
| TcpState::ListenStarted
| TcpState::ConnectReady
| TcpState::Connected
| TcpState::ConnectFailed => {
| TcpState::Connected => {
// No async operation in progress.
}
TcpState::Connecting { future } => future.wait().await,
TcpState::Listening { pending_result } => match pending_result {
Some(_) => {}
None => *pending_result = Some(self.inner.accept().await),
},
TcpState::Connecting => {
// FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
self.inner
.stream
.ready(Interest::READABLE | Interest::WRITABLE)
.await
.unwrap();
}
}
}
}
30 changes: 29 additions & 1 deletion crates/wasi/src/preview2/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::preview2::host::network::util;
use crate::preview2::network::SocketAddressFamily;
use crate::preview2::SocketAddrFamily;
use crate::preview2::{DynFuture, SocketAddrFamily};
use cap_net_ext::Blocking;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
Expand Down Expand Up @@ -93,6 +93,34 @@ impl SystemTcpSocket {
)
}

pub fn connect(&mut self, remote_address: &SocketAddr) -> DynFuture<io::Result<()>> {
fn initiate_connect(me: &SystemTcpSocket, remote_address: &SocketAddr) -> io::Result<()> {
util::validate_unicast(&remote_address)?;
util::validate_remote_address(&remote_address)?;
util::validate_address_family(&remote_address, &me.family)?;

util::tcp_connect(&me.stream, remote_address)?;
Ok(())
}

async fn await_connection(stream: Arc<tokio::net::TcpStream>) -> io::Result<()> {
stream.writable().await.unwrap();

// Check whether the connect succeeded.
match sockopt::get_socket_error(&stream) {
Ok(Ok(())) => Ok(()),
Err(err) | Ok(Err(err)) => return Err(err.into()),
}
}

match initiate_connect(self, remote_address) {
Err(e) if Errno::from_io_error(&e) == Some(Errno::INPROGRESS) => {
DynFuture::boxed(await_connection(self.stream.clone()))
}
r => DynFuture::ready(r),
}
}

pub fn listen(&mut self) -> io::Result<()> {
if self.is_listening {
return Err(io::Error::new(
Expand Down

0 comments on commit b8ae1cf

Please sign in to comment.