From 87001bf43032ba0909adf1843cd1aec48cd1b633 Mon Sep 17 00:00:00 2001 From: badeend Date: Wed, 23 Oct 2024 16:15:29 +0200 Subject: [PATCH] WIP --- Cargo.lock | 3 + Cargo.toml | 3 + crates/cli-flags/src/lib.rs | 2 + .../bin/preview2_tls_sample_application.rs | 45 +++ crates/test-programs/src/lib.rs | 2 +- crates/test-programs/src/sockets.rs | 17 ++ crates/wasi-http/Cargo.toml | 6 +- crates/wasi-http/wit/deps/sockets/tls.wit | 29 ++ crates/wasi-http/wit/deps/sockets/world.wit | 2 + crates/wasi/Cargo.toml | 3 + crates/wasi/src/bindings.rs | 4 + crates/wasi/src/host/mod.rs | 1 + crates/wasi/src/host/tls.rs | 278 ++++++++++++++++++ crates/wasi/src/lib.rs | 2 + crates/wasi/src/stream.rs | 225 ++++++++++++++ crates/wasi/tests/all/async_.rs | 17 +- crates/wasi/tests/all/sync.rs | 15 +- crates/wasi/wit/deps/sockets/tls.wit | 29 ++ crates/wasi/wit/deps/sockets/world.wit | 2 + src/common.rs | 1 + 20 files changed, 676 insertions(+), 10 deletions(-) create mode 100644 crates/test-programs/src/bin/preview2_tls_sample_application.rs create mode 100644 crates/wasi-http/wit/deps/sockets/tls.wit create mode 100644 crates/wasi/src/host/tls.rs create mode 100644 crates/wasi/wit/deps/sockets/tls.wit diff --git a/Cargo.lock b/Cargo.lock index aa4f0c45a885..341a5b6aeb8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4276,16 +4276,19 @@ dependencies = [ "io-extras", "io-lifetimes", "rustix", + "rustls 0.22.4", "system-interface", "tempfile", "test-log", "test-programs-artifacts", "thiserror", "tokio", + "tokio-rustls", "tracing", "tracing-subscriber", "url", "wasmtime", + "webpki-roots", "wiggle", "windows-sys 0.59.0", ] diff --git a/Cargo.toml b/Cargo.toml index aa5fdee31a98..4086eaced364 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -343,6 +343,9 @@ criterion = { version = "0.5.0", default-features = false, features = ["html_rep rustc-hash = "2.0.0" libtest-mimic = "0.7.0" semver = { version = "1.0.17", default-features = false } +tokio-rustls = "0.25.0" +rustls = "0.22.0" +webpki-roots = "0.26.0" # ============================================================================= # diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index 592ebd1f40e2..529f6e48f5ba 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -383,6 +383,8 @@ wasmtime_option_group! { pub udp: Option, /// Enable WASI APIs marked as: @unstable(feature = network-error-code) pub network_error_code: Option, + /// Enable WASI APIs marked as: @unstable(feature = tls) + pub tls: Option, /// Allows imports from the `wasi_unstable` core wasm module. pub preview0: Option, /// Inherit all environment variables from the parent process. diff --git a/crates/test-programs/src/bin/preview2_tls_sample_application.rs b/crates/test-programs/src/bin/preview2_tls_sample_application.rs new file mode 100644 index 000000000000..158c3d4ceaa6 --- /dev/null +++ b/crates/test-programs/src/bin/preview2_tls_sample_application.rs @@ -0,0 +1,45 @@ +use core::str; + +use test_programs::wasi::sockets::network::{IpSocketAddress, Network}; +use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket}; +use test_programs::wasi::sockets::tls; + +fn test_tls_sample_application() { + const PORT: u16 = 443; + const DOMAIN: &'static str = "example.com"; + + let request = format!("GET / HTTP/1.1\r\nHost: {DOMAIN}\r\n\r\n"); + + let net = Network::default(); + + let Some(ip) = net + .permissive_blocking_resolve_addresses(DOMAIN) + .unwrap() + .first() + .map(|a| a.to_owned()) + else { + // eprintln!("DNS lookup failed."); // TODO + panic!("DNS lookup failed."); + return; + }; + + let socket = TcpSocket::new(ip.family()).unwrap(); + let (tcp_input, tcp_output) = socket + .blocking_connect(&net, IpSocketAddress::new(ip, PORT)) + .unwrap(); + + let (_client, tls_input, tls_output) = tls::ClientHandshake::new(DOMAIN, tcp_input, tcp_output) + .blocking_finish() + .unwrap(); + + tls_output.blocking_write_util(request.as_bytes()).unwrap(); + socket.shutdown(ShutdownType::Send).unwrap(); + let response = tls_input.blocking_read_to_end().unwrap(); + let response = String::from_utf8(response).unwrap(); + + assert!(response.contains("HTTP/1.1 200 OK")); +} + +fn main() { + test_tls_sample_application(); +} diff --git a/crates/test-programs/src/lib.rs b/crates/test-programs/src/lib.rs index bdd9f1cba609..e1d830d14fe5 100644 --- a/crates/test-programs/src/lib.rs +++ b/crates/test-programs/src/lib.rs @@ -20,7 +20,7 @@ wit_bindgen::generate!({ "../wasi-keyvalue/wit", ], world: "wasmtime:test/test", - features: ["cli-exit-with-code"], + features: ["cli-exit-with-code", "tls"], generate_all, }); diff --git a/crates/test-programs/src/sockets.rs b/crates/test-programs/src/sockets.rs index 0fcccaaab4af..7a6c45c3fa6d 100644 --- a/crates/test-programs/src/sockets.rs +++ b/crates/test-programs/src/sockets.rs @@ -13,6 +13,7 @@ use crate::wasi::sockets::udp::{ IncomingDatagram, IncomingDatagramStream, OutgoingDatagram, OutgoingDatagramStream, UdpSocket, }; use crate::wasi::sockets::{tcp_create_socket, udp_create_socket}; +use crate::wasi::sockets::tls as tls; use std::ops::Range; const TIMEOUT_NS: u64 = 1_000_000_000; @@ -265,6 +266,22 @@ impl IncomingDatagramStream { } } +impl tls::ClientHandshake { + pub fn blocking_finish(self) -> Result<(tls::ClientConnection, InputStream, OutputStream), ()> { + let future = tls::ClientHandshake::finish(self); + let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS); + let pollable = future.subscribe(); + + loop { + match future.get() { + None => pollable.block_until(&timeout).expect("timed out"), + Some(Ok(r)) => return r, + Some(Err(_)) => unreachable!(), + } + } + } +} + impl IpAddress { pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255)); diff --git a/crates/wasi-http/Cargo.toml b/crates/wasi-http/Cargo.toml index 3339fc1b5743..c2ae80c6fd17 100644 --- a/crates/wasi-http/Cargo.toml +++ b/crates/wasi-http/Cargo.toml @@ -31,9 +31,9 @@ wasmtime = { workspace = true, features = ['component-model'] } # The `ring` crate, used to implement TLS, does not build on riscv64 or s390x [target.'cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))'.dependencies] -tokio-rustls = { version = "0.25.0" } -rustls = { version = "0.22.0" } -webpki-roots = { version = "0.26.0" } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +webpki-roots = { workspace = true } [dev-dependencies] test-programs-artifacts = { workspace = true } diff --git a/crates/wasi-http/wit/deps/sockets/tls.wit b/crates/wasi-http/wit/deps/sockets/tls.wit new file mode 100644 index 000000000000..886a7f27942e --- /dev/null +++ b/crates/wasi-http/wit/deps/sockets/tls.wit @@ -0,0 +1,29 @@ +@unstable(feature = tls) +interface tls { + @unstable(feature = tls) + use wasi:io/streams@0.2.2.{input-stream, output-stream}; + @unstable(feature = tls) + use wasi:io/poll@0.2.2.{pollable}; + + @unstable(feature = tls) + resource client-handshake { + @unstable(feature = tls) + constructor(server-name: string, input: input-stream, output: output-stream); + + @unstable(feature = tls) + finish: static func(this: client-handshake) -> future-client-streams; + } + + @unstable(feature = tls) + resource client-connection { + } + + @unstable(feature = tls) + resource future-client-streams { + @unstable(feature = tls) + subscribe: func() -> pollable; + + @unstable(feature = tls) + get: func() -> option>>>; + } +} diff --git a/crates/wasi-http/wit/deps/sockets/world.wit b/crates/wasi-http/wit/deps/sockets/world.wit index 6e349c756b5e..bfcfab4c2743 100644 --- a/crates/wasi-http/wit/deps/sockets/world.wit +++ b/crates/wasi-http/wit/deps/sockets/world.wit @@ -16,4 +16,6 @@ world imports { import tcp-create-socket; @since(version = 0.2.0) import ip-name-lookup; + @unstable(feature = tls) + import tls; } diff --git a/crates/wasi/Cargo.toml b/crates/wasi/Cargo.toml index db02061ce315..57bfcba7eaf7 100644 --- a/crates/wasi/Cargo.toml +++ b/crates/wasi/Cargo.toml @@ -35,6 +35,9 @@ async-trait = { workspace = true } system-interface = { workspace = true} futures = { workspace = true } url = { workspace = true } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +webpki-roots = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["time", "sync", "io-std", "io-util", "rt", "rt-multi-thread", "net", "macros", "fs"] } diff --git a/crates/wasi/src/bindings.rs b/crates/wasi/src/bindings.rs index 72c2c3ff9943..754f6d835a33 100644 --- a/crates/wasi/src/bindings.rs +++ b/crates/wasi/src/bindings.rs @@ -164,6 +164,7 @@ pub mod sync { "wasi:io/error": crate::bindings::io::error, "wasi:filesystem/preopens": crate::bindings::filesystem::preopens, "wasi:sockets/network": crate::bindings::sockets::network, + "wasi:sockets/tls": crate::bindings::sockets::tls, // Configure the resource types of the bound interfaces here // to be the same as the async versions of the resources, that @@ -406,6 +407,9 @@ mod async_io { "wasi:sockets/udp/incoming-datagram-stream": crate::udp::IncomingDatagramStream, "wasi:sockets/udp/outgoing-datagram-stream": crate::udp::OutgoingDatagramStream, "wasi:sockets/ip-name-lookup/resolve-address-stream": crate::ip_name_lookup::ResolveAddressStream, + "wasi:sockets/tls/client-connection": crate::host::tls::ClientConnection, + "wasi:sockets/tls/client-handshake": crate::host::tls::ClientHandshake, + "wasi:sockets/tls/future-client-streams": crate::host::tls::FutureClientStreams, "wasi:filesystem/types/directory-entry-stream": crate::filesystem::ReaddirIterator, "wasi:filesystem/types/descriptor": crate::filesystem::Descriptor, "wasi:io/streams/input-stream": crate::stream::InputStream, diff --git a/crates/wasi/src/host/mod.rs b/crates/wasi/src/host/mod.rs index 7aa4a8741a9b..9ed0d2649f0c 100644 --- a/crates/wasi/src/host/mod.rs +++ b/crates/wasi/src/host/mod.rs @@ -8,5 +8,6 @@ pub(crate) mod network; mod random; mod tcp; mod tcp_create_socket; +pub(crate) mod tls; mod udp; mod udp_create_socket; diff --git a/crates/wasi/src/host/tls.rs b/crates/wasi/src/host/tls.rs new file mode 100644 index 000000000000..f0367a59f7e9 --- /dev/null +++ b/crates/wasi/src/host/tls.rs @@ -0,0 +1,278 @@ +use crate::bindings::io::streams::{InputStream, OutputStream}; +use crate::pipe::{AsyncReadStream, AsyncWriteStream}; +use crate::stream::DuplexAsyncReadWrite; +use crate::{ + HostInputStream, HostOutputStream, Pollable, StreamError, Subscribe, WasiImpl, WasiView, +}; +use rustls::pki_types::ServerName; +use std::future::Future; +use wasmtime::component::{Resource, ResourceTable}; + +impl crate::bindings::sockets::tls::Host for WasiImpl where T: WasiView {} + +pub struct ClientHandshake { + tombstones: Tombstones, + server_name: String, + input: InputStream, + output: OutputStream, +} + +pub struct FutureClientStreams { + tombstones: Option, + connect: PollableFuture>>, +} + +#[async_trait::async_trait] +impl Subscribe for FutureClientStreams { + async fn ready(&mut self) { + self.connect.ready().await + } +} + +pub struct ClientConnection { + tombstones: Tombstones, +} + +#[async_trait::async_trait] +impl crate::bindings::sockets::tls::HostClientConnection for WasiImpl +where + T: WasiView, +{ + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?.tombstones.delete(self.table())?; + Ok(()) + } +} + +#[async_trait::async_trait] +impl crate::bindings::sockets::tls::HostClientHandshake for WasiImpl +where + T: WasiView, +{ + fn new( + &mut self, + server_name: String, + input: Resource, + output: Resource, + ) -> wasmtime::Result> { + // Take the provided streams out of the table, but keep their indexes + // alive to prevent the child resources outliving their parents: + let input_stream = std::mem::replace(self.table().get_mut(&input)?, Box::new(Tombstone)); + let output_stream = std::mem::replace(self.table().get_mut(&output)?, Box::new(Tombstone)); + + Ok(self.table().push(ClientHandshake { + server_name, + input: input_stream, + output: output_stream, + tombstones: Tombstones { + input: input, + output: output, + }, + })?) + } + + fn finish( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + let handshake = self.table().delete(this)?; + + let server_name = handshake.server_name; + let inner_stream = DuplexAsyncReadWrite::new(handshake.input, handshake.output); + + let root_cert_store = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = std::sync::Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(), + ); + + let config_clone = config.clone(); + + Ok(self.table().push(FutureClientStreams { + tombstones: Some(handshake.tombstones), + connect: PollableFuture::new(crate::runtime::spawn(async move { + let connector = tokio_rustls::TlsConnector::from(config_clone); + let domain = ServerName::try_from(server_name)?; + let stream = connector.connect(domain, inner_stream).await?; + Ok(stream) + })), + })?) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?.tombstones.delete(self.table())?; + Ok(()) + } +} + +#[async_trait::async_trait] +impl crate::bindings::sockets::tls::HostFutureClientStreams for WasiImpl +where + T: WasiView, +{ + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + crate::poll::subscribe(self.table(), this) + } + + fn get( + &mut self, + this: Resource, + ) -> wasmtime::Result< + Option< + Result< + Result< + ( + Resource, + Resource, + Resource, + ), + (), + >, + (), + >, + >, + > { + let this = self.table().get_mut(&this)?; + + let tls_stream = match this.connect.take_ready() { + TakeReady::Ready(Ok(tls_stream)) => tls_stream, + TakeReady::Ready(Err(e)) => return Ok(Some(Ok(Err(())))), // TODO: don't throw away error + TakeReady::Pending => return Ok(None), + TakeReady::Consumed => return Ok(Some(Err(()))), + }; + + let client = ClientConnection { + tombstones: this.tombstones.take().unwrap(), + }; + + let (rx, tx) = tokio::io::split(tls_stream); + + let input: InputStream = Box::new(AsyncReadStream::new(rx)); + let output: OutputStream = Box::new(AsyncWriteStream::new(64 * 1024, tx)); // TODO: buffer size + + let client = self.table().push(client)?; + let input = self.table().push_child(input, &client)?; + let output = self.table().push_child(output, &client)?; + + Ok(Some(Ok(Ok((client, input, output))))) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + let mut future = self.table().delete(this)?; + + match future.tombstones.take() { + Some(c) => c.delete(self.table())?, + None => {} + } + Ok(()) + } +} + +/// Placeholder stream type that is substituted in place to keep the +/// parent<->child resource lifetime restrictions in tact. +struct Tombstone; + +#[async_trait::async_trait] +impl HostInputStream for Tombstone { + fn read(&mut self, _size: usize) -> Result { + Err(StreamError::trap("stream has been consumed")) + } +} + +#[async_trait::async_trait] +impl HostOutputStream for Tombstone { + fn write(&mut self, _bytes: bytes::Bytes) -> Result<(), StreamError> { + Err(StreamError::trap("stream has been consumed")) + } + + fn flush(&mut self) -> Result<(), StreamError> { + Err(StreamError::trap("stream has been consumed")) + } + + fn check_write(&mut self) -> Result { + Err(StreamError::trap("stream has been consumed")) + } +} + +#[async_trait::async_trait] +impl Subscribe for Tombstone { + async fn ready(&mut self) {} +} + +struct Tombstones { + input: Resource, + output: Resource, +} +impl Tombstones { + fn delete(self, table: &mut ResourceTable) -> wasmtime::Result<()> { + table.delete(self.input)?; + table.delete(self.output)?; + Ok(()) + } +} + +enum PollableFuture { + Pending(std::pin::Pin + Send>>), + Ready(T), + Consumed, +} +enum TakeReady { + Ready(T), + Pending, + Consumed, +} +impl PollableFuture { + pub fn new(fut: F) -> Self + where + F: Future + Send + 'static, + { + Self::Pending(Box::pin(fut)) + } + + pub fn take_ready(&mut self) -> TakeReady { + match std::mem::replace(self, Self::Consumed) { + Self::Ready(value) => TakeReady::Ready(value), + Self::Pending(fut) => { + *self = Self::Pending(fut); + TakeReady::Pending + } + Self::Consumed => TakeReady::Consumed, + } + } +} +#[async_trait::async_trait] +impl Subscribe for PollableFuture { + async fn ready(&mut self) { + match self { + Self::Pending(fut) => *self = Self::Ready(fut.await), + Self::Ready(_) | Self::Consumed => {} + } + } +} +impl Future for PollableFuture { + type Output = T; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let _self = self.get_mut(); + match std::mem::replace(_self, Self::Consumed) { + Self::Pending(mut fut) => match std::pin::pin!(&mut fut).poll(cx) { + std::task::Poll::Ready(value) => std::task::Poll::Ready(value), + std::task::Poll::Pending => { + *_self = Self::Pending(fut); + std::task::Poll::Pending + } + }, + Self::Ready(value) => std::task::Poll::Ready(value), + Self::Consumed => panic!("can't await consumed pollable"), + } + } +} diff --git a/crates/wasi/src/lib.rs b/crates/wasi/src/lib.rs index ddb389964316..38fa779ed7d3 100644 --- a/crates/wasi/src/lib.rs +++ b/crates/wasi/src/lib.rs @@ -325,6 +325,7 @@ pub fn add_to_linker_with_options_async( crate::bindings::sockets::instance_network::add_to_linker_get_host(l, closure)?; crate::bindings::sockets::network::add_to_linker_get_host(l, &options.into(), closure)?; crate::bindings::sockets::ip_name_lookup::add_to_linker_get_host(l, closure)?; + crate::bindings::sockets::tls::add_to_linker_get_host(l, &options.into(), closure)?; Ok(()) } @@ -424,6 +425,7 @@ pub fn add_to_linker_with_options_sync( crate::bindings::sockets::instance_network::add_to_linker_get_host(l, closure)?; crate::bindings::sockets::network::add_to_linker_get_host(l, &options.into(), closure)?; crate::bindings::sockets::ip_name_lookup::add_to_linker_get_host(l, closure)?; + crate::bindings::sync::sockets::tls::add_to_linker_get_host(l, &options.into(), closure)?; Ok(()) } diff --git a/crates/wasi/src/stream.rs b/crates/wasi/src/stream.rs index 251133cdac66..f87392fdbcd0 100644 --- a/crates/wasi/src/stream.rs +++ b/crates/wasi/src/stream.rs @@ -1,3 +1,9 @@ +use std::{ + future::Future, + pin::{pin, Pin}, + task::{ready, Poll}, +}; + use crate::poll::Subscribe; use anyhow::Result; use bytes::Bytes; @@ -263,3 +269,222 @@ impl Subscribe for Box { pub type InputStream = Box; pub type OutputStream = Box; + +pub struct InputAsyncRead { + state: AsyncReadState, +} +enum AsyncReadState { + Ready(InputStream), + Pending(Pin + Send>>), + Closed, +} +impl InputAsyncRead { + pub fn new(input: InputStream) -> Self { + InputAsyncRead { + state: AsyncReadState::Ready(input), + } + } +} +impl tokio::io::AsyncRead for InputAsyncRead { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + loop { + let stream = match &mut self.state { + AsyncReadState::Ready(stream) => stream, + AsyncReadState::Pending(fut) => { + let stream = ready!(fut.as_mut().poll(cx)); + self.state = AsyncReadState::Ready(stream); + if let AsyncReadState::Ready(stream) = &mut self.state { + stream + } else { + unreachable!() + } + } + AsyncReadState::Closed => return Poll::Ready(Ok(())), + }; + + // FYI, POSIX and the `AsyncRead` contract defines that a 0-byte + // result indicates the end of the stream. In WASI, a 0-byte result + // indicates that the stream isn't ready I/O right now, i.e. EWOULDBLOCK. + match stream.read(buf.remaining()) { + Ok(bytes) if bytes.is_empty() => { + let AsyncReadState::Ready(mut stream) = + std::mem::replace(&mut self.state, AsyncReadState::Closed) + else { + unreachable!() + }; + + self.state = AsyncReadState::Pending(Box::pin(async move { + stream.ready().await; + stream + })); + + // Continue looping + } + Ok(bytes) => { + buf.put_slice(&bytes); + return Poll::Ready(Ok(())); + } + Err(StreamError::Closed) => { + self.state = AsyncReadState::Closed; + return Poll::Ready(Ok(())); + } + Err(e) => { + self.state = AsyncReadState::Closed; + return Poll::Ready(Err(std::io::Error::other(e))); + } + } + } + } +} + +pub struct OutputAsyncWrite { + state: AsyncWriteState, +} +enum AsyncWriteState { + Ready(OutputStream), + Pending(Pin + Send>>), + Closed, +} +impl OutputAsyncWrite { + pub fn new(output: OutputStream) -> Self { + OutputAsyncWrite { + state: AsyncWriteState::Ready(output), + } + } +} +impl tokio::io::AsyncWrite for OutputAsyncWrite { + // This `poll_write` implementation interprets a write of 0 bytes as an I/O + // readiness check. The `poll_flush` implementation below depends on this behavior. + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + let stream = match &mut self.state { + AsyncWriteState::Ready(stream) => stream, + AsyncWriteState::Pending(fut) => { + let stream = ready!(fut.as_mut().poll(cx)); + self.state = AsyncWriteState::Ready(stream); + if let AsyncWriteState::Ready(stream) = &mut self.state { + stream + } else { + unreachable!() + } + } + AsyncWriteState::Closed => return Poll::Ready(Ok(0)), + }; + + // FYI, the `AsyncWrite` contract defines that a 0-byte result + // indicates the end of the stream. In WASI, a 0-byte result indicates + // that the stream isn't ready I/O right now, i.e. EWOULDBLOCK. + match stream.check_write() { + Ok(0) => { + let AsyncWriteState::Ready(mut stream) = + std::mem::replace(&mut self.state, AsyncWriteState::Closed) + else { + unreachable!() + }; + + self.state = AsyncWriteState::Pending(Box::pin(async move { + stream.ready().await; + stream + })); + + // Continue looping + } + Ok(n) => { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + let size = n.min(buf.len()); + return match stream.write(Bytes::copy_from_slice(&buf[..size])) { + Ok(()) => Poll::Ready(Ok(size)), + Err(StreamError::Closed) => { + self.state = AsyncWriteState::Closed; + Poll::Ready(Ok(0)) + } + Err(e) => { + self.state = AsyncWriteState::Closed; + Poll::Ready(Err(std::io::Error::other(e))) + } + }; + } + Err(StreamError::Closed) => { + self.state = AsyncWriteState::Closed; + return Poll::Ready(Ok(0)); + } + Err(e) => { + self.state = AsyncWriteState::Closed; + return Poll::Ready(Err(std::io::Error::other(e))); + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_write(cx, &[]).map_ok(|_| ()) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let result = ready!(self.as_mut().poll_flush(cx)); + self.state = AsyncWriteState::Closed; + Poll::Ready(result) + } +} + +pub struct DuplexAsyncReadWrite { + read: InputAsyncRead, + write: OutputAsyncWrite, +} +impl DuplexAsyncReadWrite { + pub fn new(input: InputStream, output: OutputStream) -> Self { + DuplexAsyncReadWrite { + read: InputAsyncRead::new(input), + write: OutputAsyncWrite::new(output), + } + } +} +impl tokio::io::AsyncRead for DuplexAsyncReadWrite { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + pin!(&mut self.read).poll_read(cx, buf) + } +} +impl tokio::io::AsyncWrite for DuplexAsyncReadWrite { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + pin!(&mut self.write).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + pin!(&mut self.write).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + pin!(&mut self.write).poll_shutdown(cx) + } +} diff --git a/crates/wasi/tests/all/async_.rs b/crates/wasi/tests/all/async_.rs index 7d25423726b2..cacb64379f54 100644 --- a/crates/wasi/tests/all/async_.rs +++ b/crates/wasi/tests/all/async_.rs @@ -1,17 +1,20 @@ use super::*; use std::path::Path; use test_programs_artifacts::*; -use wasmtime_wasi::add_to_linker_async; -use wasmtime_wasi::bindings::Command; +use wasmtime_wasi::add_to_linker_with_options_async; +use wasmtime_wasi::bindings::{Command, LinkOptions}; async fn run(path: &str, inherit_stdio: bool) -> Result<()> { + run_with_options(path, inherit_stdio, &LinkOptions::default()).await +} +async fn run_with_options(path: &str, inherit_stdio: bool, options: &LinkOptions) -> Result<()> { let path = Path::new(path); let name = path.file_stem().unwrap().to_str().unwrap(); let engine = test_programs_artifacts::engine(|config| { config.async_support(true); }); let mut linker = Linker::new(&engine); - add_to_linker_async(&mut linker)?; + add_to_linker_with_options_async(&mut linker, options)?; let (mut store, _td) = store(&engine, name, |builder| { if inherit_stdio { @@ -336,6 +339,14 @@ async fn preview2_tcp_connect() { run(PREVIEW2_TCP_CONNECT_COMPONENT, false).await.unwrap() } #[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn preview2_tls_sample_application() { + let mut options = LinkOptions::default(); + options.tls(true); + run_with_options(PREVIEW2_TLS_SAMPLE_APPLICATION_COMPONENT, false, &options) + .await + .unwrap() +} +#[test_log::test(tokio::test(flavor = "multi_thread"))] async fn preview2_udp_sockopts() { run(PREVIEW2_UDP_SOCKOPTS_COMPONENT, false).await.unwrap() } diff --git a/crates/wasi/tests/all/sync.rs b/crates/wasi/tests/all/sync.rs index 5629d5d3ffd8..846072e69e24 100644 --- a/crates/wasi/tests/all/sync.rs +++ b/crates/wasi/tests/all/sync.rs @@ -1,15 +1,18 @@ use super::*; use std::path::Path; use test_programs_artifacts::*; -use wasmtime_wasi::add_to_linker_sync; -use wasmtime_wasi::bindings::sync::Command; +use wasmtime_wasi::add_to_linker_with_options_sync; +use wasmtime_wasi::bindings::sync::{Command, LinkOptions}; fn run(path: &str, inherit_stdio: bool) -> Result<()> { + run_with_options(path, inherit_stdio, &LinkOptions::default()) +} +fn run_with_options(path: &str, inherit_stdio: bool, options: &LinkOptions) -> Result<()> { let path = Path::new(path); let name = path.file_stem().unwrap().to_str().unwrap(); let engine = test_programs_artifacts::engine(|_| {}); let mut linker = Linker::new(&engine); - add_to_linker_sync(&mut linker)?; + add_to_linker_with_options_sync(&mut linker, options)?; let component = Component::from_file(&engine, path)?; @@ -282,6 +285,12 @@ fn preview2_tcp_connect() { run(PREVIEW2_TCP_CONNECT_COMPONENT, false).unwrap() } #[test_log::test] +fn preview2_tls_sample_application() { + let mut options = LinkOptions::default(); + options.tls(true); + run_with_options(PREVIEW2_TLS_SAMPLE_APPLICATION_COMPONENT, false, &options).unwrap() +} +#[test_log::test] fn preview2_udp_sockopts() { run(PREVIEW2_UDP_SOCKOPTS_COMPONENT, false).unwrap() } diff --git a/crates/wasi/wit/deps/sockets/tls.wit b/crates/wasi/wit/deps/sockets/tls.wit new file mode 100644 index 000000000000..886a7f27942e --- /dev/null +++ b/crates/wasi/wit/deps/sockets/tls.wit @@ -0,0 +1,29 @@ +@unstable(feature = tls) +interface tls { + @unstable(feature = tls) + use wasi:io/streams@0.2.2.{input-stream, output-stream}; + @unstable(feature = tls) + use wasi:io/poll@0.2.2.{pollable}; + + @unstable(feature = tls) + resource client-handshake { + @unstable(feature = tls) + constructor(server-name: string, input: input-stream, output: output-stream); + + @unstable(feature = tls) + finish: static func(this: client-handshake) -> future-client-streams; + } + + @unstable(feature = tls) + resource client-connection { + } + + @unstable(feature = tls) + resource future-client-streams { + @unstable(feature = tls) + subscribe: func() -> pollable; + + @unstable(feature = tls) + get: func() -> option>>>; + } +} diff --git a/crates/wasi/wit/deps/sockets/world.wit b/crates/wasi/wit/deps/sockets/world.wit index 6e349c756b5e..bfcfab4c2743 100644 --- a/crates/wasi/wit/deps/sockets/world.wit +++ b/crates/wasi/wit/deps/sockets/world.wit @@ -16,4 +16,6 @@ world imports { import tcp-create-socket; @since(version = 0.2.0) import ip-name-lookup; + @unstable(feature = tls) + import tls; } diff --git a/src/common.rs b/src/common.rs index c0e9a7e1a365..2065757c2196 100644 --- a/src/common.rs +++ b/src/common.rs @@ -338,6 +338,7 @@ impl RunCommon { let mut options = LinkOptions::default(); options.cli_exit_with_code(self.common.wasi.cli_exit_with_code.unwrap_or(false)); options.network_error_code(self.common.wasi.network_error_code.unwrap_or(false)); + options.tls(self.common.wasi.tls.unwrap_or(false)); options } }