From b8ae1cf9acfff736520627c2d71d3b4f9fc6a361 Mon Sep 17 00:00:00 2001 From: Dave Bakker Date: Thu, 11 Jan 2024 23:11:41 +0100 Subject: [PATCH] Move `connect` into SystemTcpSocket --- crates/wasi/src/preview2/host/tcp.rs | 161 ++++++++------------------- crates/wasi/src/preview2/tcp.rs | 30 ++++- 2 files changed, 78 insertions(+), 113 deletions(-) diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index 4ba97fae2c14..34992ed8304d 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -7,15 +7,14 @@ use crate::preview2::network::SocketAddrUse; use crate::preview2::pipe::{AsyncReadStream, AsyncWriteStream}; use crate::preview2::tcp::SystemTcpSocket; use crate::preview2::{ - InputStream, OutputStream, Pollable, SocketAddrFamily, SocketResult, Subscribe, WasiView, + DynFuture, InputStream, OutputStream, Pollable, SocketAddrFamily, SocketResult, Subscribe, + WasiView, }; use rustix::io::Errno; -use rustix::net::sockopt; use std::io; use std::net::SocketAddr; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::io::Interest; use wasmtime::component::Resource; impl crate::preview2::bindings::sockets::tcp::Host for T {} @@ -44,13 +43,7 @@ enum TcpState { }, /// An outgoing connection is started via `start_connect`. - Connecting, - - /// An outgoing connection is ready to be established. - ConnectReady, - - /// An outgoing connection was attempted but failed. - ConnectFailed, + Connecting { future: DynFuture> }, /// An outgoing connection has been established. Connected, @@ -180,57 +173,31 @@ impl crate::preview2::bindings::sockets::tcp::HostTcpSocket for T { network: Resource, remote_address: IpSocketAddress, ) -> SocketResult<()> { - self.ctx().allowed_network_uses.check_allowed_tcp()?; let table = self.table_mut(); - let r = { - let socket = table.get(&this)?; - let network = table.get(&network)?; - let remote_address: SocketAddr = remote_address.into(); - - match socket.tcp_state { - TcpState::Default => {} - TcpState::Bound - | TcpState::Connected - | TcpState::ConnectFailed - | TcpState::Listening { .. } => return Err(ErrorCode::InvalidState.into()), - TcpState::Connecting - | TcpState::ConnectReady - | TcpState::ListenStarted - | TcpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), - } - - util::validate_unicast(&remote_address)?; - util::validate_remote_address(&remote_address)?; - util::validate_address_family(&remote_address, &socket.inner.family)?; - // Ensure that we're allowed to connect to this address. - network.check_socket_addr(&remote_address, SocketAddrUse::TcpConnect)?; - - // Do an OS `connect`. Our socket is non-blocking, so it'll either... - util::tcp_connect(socket.tcp_socket(), &remote_address) - }; + // At the moment, there's only one network handle (`instance-network`) + // in existence. All we have to do here is validate that the caller indeed + // has possesion of a valid handle and then we're good to go: + let _network = table.get(&network)?; + let socket = table.get_mut(&this)?; + let remote_address: SocketAddr = remote_address.into(); - match r { - // succeed immediately, - Ok(()) => { - let socket = table.get_mut(&this)?; - socket.tcp_state = TcpState::ConnectReady; - return Ok(()); - } - // continue in progress, - Err(err) if err == Errno::INPROGRESS => {} - // or fail immediately. - Err(err) => { - return Err(match err { - Errno::AFNOSUPPORT => ErrorCode::InvalidArgument.into(), // See `bind` implementation. - _ => err.into(), - }); - } + match socket.tcp_state { + TcpState::Default => {} + TcpState::Connecting { .. } => return Err(ErrorCode::ConcurrencyConflict.into()), + _ => return Err(ErrorCode::InvalidState.into()), } - let socket = table.get_mut(&this)?; - socket.tcp_state = TcpState::Connecting; + let mut future = socket.inner.connect(&remote_address); + // Attempt to return (validation) errors immediately: + let future = match future.try_resolve() { + Some(Err(e)) => return Err(e.into()), + Some(Ok(())) => DynFuture::ready(Ok(())), + None => future, + }; + + socket.tcp_state = TcpState::Connecting { future }; Ok(()) } @@ -241,41 +208,26 @@ impl crate::preview2::bindings::sockets::tcp::HostTcpSocket for T { let table = self.table_mut(); let socket = table.get_mut(&this)?; - match socket.tcp_state { - TcpState::ConnectReady => {} - TcpState::Connecting => { - // Do a `poll` to test for completion, using a timeout of zero - // to avoid blocking. - match rustix::event::poll( - &mut [rustix::event::PollFd::new( - socket.tcp_socket(), - rustix::event::PollFlags::OUT, - )], - 0, - ) { - Ok(0) => return Err(ErrorCode::WouldBlock.into()), - Ok(_) => (), - Err(err) => Err(err).unwrap(), - } - - // Check whether the connect succeeded. - match sockopt::get_socket_error(socket.tcp_socket()) { - Ok(Ok(())) => {} - Err(err) | Ok(Err(err)) => { - socket.tcp_state = TcpState::ConnectFailed; - return Err(err.into()); - } - } - } - _ => return Err(ErrorCode::NotInProgress.into()), + let TcpState::Connecting { future } = &mut socket.tcp_state else { + return Err(ErrorCode::NotInProgress.into()); }; - socket.tcp_state = TcpState::Connected; - let (input, output) = socket.as_split(); - let input_stream = self.table_mut().push_child(input, &this)?; - let output_stream = self.table_mut().push_child(output, &this)?; + match future.try_resolve() { + Some(Ok(())) => { + socket.tcp_state = TcpState::Connected; - Ok((input_stream, output_stream)) + let (input, output) = socket.as_split(); + let input_stream = self.table_mut().push_child(input, &this)?; + let output_stream = self.table_mut().push_child(output, &this)?; + + Ok((input_stream, output_stream)) + } + Some(Err(e)) => { + socket.tcp_state = TcpState::Default; + Err(e.into()) + } + None => Err(ErrorCode::WouldBlock.into()), + } } fn start_listen(&mut self, this: Resource) -> SocketResult<()> { @@ -285,14 +237,12 @@ impl crate::preview2::bindings::sockets::tcp::HostTcpSocket for T { match socket.tcp_state { TcpState::Bound => {} - TcpState::Default - | TcpState::Connected - | TcpState::ConnectFailed - | TcpState::Listening { .. } => return Err(ErrorCode::InvalidState.into()), - TcpState::ListenStarted - | TcpState::Connecting - | TcpState::ConnectReady - | TcpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), + TcpState::Default | TcpState::Connected | TcpState::Listening { .. } => { + return Err(ErrorCode::InvalidState.into()) + } + TcpState::ListenStarted | TcpState::Connecting { .. } | TcpState::BindStarted => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } } socket.inner.listen()?; @@ -383,9 +333,7 @@ impl crate::preview2::bindings::sockets::tcp::HostTcpSocket for T { match socket.tcp_state { TcpState::Connected => {} - TcpState::Connecting | TcpState::ConnectReady => { - return Err(ErrorCode::ConcurrencyConflict.into()) - } + TcpState::Connecting { .. } => return Err(ErrorCode::ConcurrencyConflict.into()), _ => return Err(ErrorCode::InvalidState.into()), } @@ -573,9 +521,7 @@ impl crate::preview2::bindings::sockets::tcp::HostTcpSocket for T { match socket.tcp_state { TcpState::Connected => {} - TcpState::Connecting | TcpState::ConnectReady => { - return Err(ErrorCode::ConcurrencyConflict.into()) - } + TcpState::Connecting { .. } => return Err(ErrorCode::ConcurrencyConflict.into()), _ => return Err(ErrorCode::InvalidState.into()), } @@ -607,23 +553,14 @@ impl Subscribe for TcpSocketWrapper { | TcpState::BindStarted | TcpState::Bound | TcpState::ListenStarted - | TcpState::ConnectReady - | TcpState::Connected - | TcpState::ConnectFailed => { + | TcpState::Connected => { // No async operation in progress. } + TcpState::Connecting { future } => future.wait().await, TcpState::Listening { pending_result } => match pending_result { Some(_) => {} None => *pending_result = Some(self.inner.accept().await), }, - TcpState::Connecting => { - // FIXME: Add `Interest::ERROR` when we update to tokio 1.32. - self.inner - .stream - .ready(Interest::READABLE | Interest::WRITABLE) - .await - .unwrap(); - } } } } diff --git a/crates/wasi/src/preview2/tcp.rs b/crates/wasi/src/preview2/tcp.rs index 202731635359..a04a16166afb 100644 --- a/crates/wasi/src/preview2/tcp.rs +++ b/crates/wasi/src/preview2/tcp.rs @@ -1,6 +1,6 @@ use crate::preview2::host::network::util; use crate::preview2::network::SocketAddressFamily; -use crate::preview2::SocketAddrFamily; +use crate::preview2::{DynFuture, SocketAddrFamily}; use cap_net_ext::Blocking; use io_lifetimes::AsSocketlike; use rustix::io::Errno; @@ -93,6 +93,34 @@ impl SystemTcpSocket { ) } + pub fn connect(&mut self, remote_address: &SocketAddr) -> DynFuture> { + fn initiate_connect(me: &SystemTcpSocket, remote_address: &SocketAddr) -> io::Result<()> { + util::validate_unicast(&remote_address)?; + util::validate_remote_address(&remote_address)?; + util::validate_address_family(&remote_address, &me.family)?; + + util::tcp_connect(&me.stream, remote_address)?; + Ok(()) + } + + async fn await_connection(stream: Arc) -> io::Result<()> { + stream.writable().await.unwrap(); + + // Check whether the connect succeeded. + match sockopt::get_socket_error(&stream) { + Ok(Ok(())) => Ok(()), + Err(err) | Ok(Err(err)) => return Err(err.into()), + } + } + + match initiate_connect(self, remote_address) { + Err(e) if Errno::from_io_error(&e) == Some(Errno::INPROGRESS) => { + DynFuture::boxed(await_connection(self.stream.clone())) + } + r => DynFuture::ready(r), + } + } + pub fn listen(&mut self) -> io::Result<()> { if self.is_listening { return Err(io::Error::new(