Skip to content

Commit

Permalink
Remove special case for FileInputStream and change InputStream enum t…
Browse files Browse the repository at this point in the history
…o be a type alias, just like OutputStream
  • Loading branch information
badeend committed Aug 11, 2024
1 parent 6da19c8 commit 33d38f9
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 85 deletions.
2 changes: 1 addition & 1 deletion crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ where
let body = self.table().get_mut(&id)?;

if let Some(stream) = body.take_stream() {
let stream = InputStream::Host(Box::new(stream));
let stream: InputStream = Box::new(stream);
let stream = self.table().push_child(stream, &id)?;
return Ok(Ok(stream));
}
Expand Down
3 changes: 0 additions & 3 deletions crates/wasi/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,10 @@ mod async_io {
"[method]descriptor.unlink-file-at",
"[method]descriptor.unlock",
"[method]descriptor.write",
"[method]input-stream.read",
"[method]input-stream.blocking-read",
"[method]input-stream.blocking-skip",
"[method]input-stream.skip",
"[drop]input-stream",
"[method]output-stream.forward",
"[method]output-stream.splice",
"[method]output-stream.blocking-splice",
"[method]output-stream.blocking-flush",
"[method]output-stream.blocking-write",
Expand Down
32 changes: 0 additions & 32 deletions crates/wasi/src/filesystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,38 +320,6 @@ impl FileInputStream {
self.position += min_len as u64;
chunk
}

pub async fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
use system_interface::fs::FileIoExt;
let p = self.position;

let (r, mut buf) = self
.file
.run_blocking(move |f| {
let mut buf = BytesMut::zeroed(size);
let r = f.read_at(&mut buf, p);
(r, buf)
})
.await;
let n = read_result(r, size)?;
buf.truncate(n);
self.position += n as u64;
Ok(buf.freeze())
}

pub async fn skip(&mut self, nelem: usize) -> Result<usize, StreamError> {
let bs = self.read(nelem).await?;
Ok(bs.len())
}
}

fn read_result(r: io::Result<usize>, size: usize) -> Result<usize, StreamError> {
match r {
Ok(0) if size > 0 => Err(StreamError::Closed),
Ok(n) => Ok(n),
Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
Err(e) => Err(StreamError::LastOperationFailed(e.into())),
}
}
#[async_trait::async_trait]
impl HostInputStream for FileInputStream {
Expand Down
4 changes: 2 additions & 2 deletions crates/wasi/src/host/filesystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,10 +746,10 @@ where
}

// Create a stream view for it.
let reader = FileInputStream::new(f, offset);
let reader: InputStream = Box::new(FileInputStream::new(f, offset));

// Insert the stream view into the table. Trap if the table is full.
let index = self.table().push(InputStream::Host(Box::new(reader)))?;
let index = self.table().push(reader)?;

Ok(index)
}
Expand Down
44 changes: 13 additions & 31 deletions crates/wasi/src/host/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ where
Ok(())
}

async fn splice(
fn splice(
&mut self,
dest: Resource<OutputStream>,
src: Resource<InputStream>,
Expand All @@ -129,10 +129,7 @@ where
return Ok(0);
}

let contents = match self.table().get_mut(&src)? {
InputStream::Host(h) => h.read(len)?,
InputStream::File(f) => f.read(len).await?,
};
let contents = self.table().get_mut(&src)?.read(len)?;

let len = contents.len();
if len == 0 {
Expand All @@ -156,7 +153,7 @@ where

self.table().get_mut(&src)?.ready().await;

self.splice(dest, src, len).await
self.splice(dest, src, len)
}
}

Expand All @@ -166,19 +163,13 @@ where
T: WasiView,
{
async fn drop(&mut self, stream: Resource<InputStream>) -> anyhow::Result<()> {
match self.table().delete(stream)? {
InputStream::Host(mut s) => s.cancel().await,
InputStream::File(_) => {}
}
self.table().delete(stream)?.cancel().await;
Ok(())
}

async fn read(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<Vec<u8>> {
fn read(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<Vec<u8>> {
let len = len.try_into().unwrap_or(usize::MAX);
let bytes = match self.table().get_mut(&stream)? {
InputStream::Host(s) => s.read(len)?,
InputStream::File(s) => s.read(len).await?,
};
let bytes = self.table().get_mut(&stream)?.read(len)?;
debug_assert!(bytes.len() <= len);
Ok(bytes.into())
}
Expand All @@ -189,20 +180,14 @@ where
len: u64,
) -> StreamResult<Vec<u8>> {
let len = len.try_into().unwrap_or(usize::MAX);
let bytes = match self.table().get_mut(&stream)? {
InputStream::Host(s) => s.blocking_read(len).await?,
InputStream::File(s) => s.read(len).await?,
};
let bytes = self.table().get_mut(&stream)?.blocking_read(len).await?;
debug_assert!(bytes.len() <= len);
Ok(bytes.into())
}

async fn skip(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<u64> {
fn skip(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<u64> {
let len = len.try_into().unwrap_or(usize::MAX);
let written = match self.table().get_mut(&stream)? {
InputStream::Host(s) => s.skip(len)?,
InputStream::File(s) => s.skip(len).await?,
};
let written = self.table().get_mut(&stream)?.skip(len)?;
Ok(written.try_into().expect("usize always fits in u64"))
}

Expand All @@ -212,10 +197,7 @@ where
len: u64,
) -> StreamResult<u64> {
let len = len.try_into().unwrap_or(usize::MAX);
let written = match self.table().get_mut(&stream)? {
InputStream::Host(s) => s.blocking_skip(len).await?,
InputStream::File(s) => s.skip(len).await?,
};
let written = self.table().get_mut(&stream)?.blocking_skip(len).await?;
Ok(written.try_into().expect("usize always fits in u64"))
}

Expand Down Expand Up @@ -325,7 +307,7 @@ pub mod sync {
src: Resource<InputStream>,
len: u64,
) -> StreamResult<u64> {
in_tokio(async { AsyncHostOutputStream::splice(self, dst, src, len).await })
AsyncHostOutputStream::splice(self, dst, src, len)
}

fn blocking_splice(
Expand All @@ -347,7 +329,7 @@ pub mod sync {
}

fn read(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<Vec<u8>> {
in_tokio(async { AsyncHostInputStream::read(self, stream, len).await })
AsyncHostInputStream::read(self, stream, len)
}

fn blocking_read(
Expand All @@ -359,7 +341,7 @@ pub mod sync {
}

fn skip(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<u64> {
in_tokio(async { AsyncHostInputStream::skip(self, stream, len).await })
AsyncHostInputStream::skip(self, stream, len)
}

fn blocking_skip(&mut self, stream: Resource<InputStream>, len: u64) -> StreamResult<u64> {
Expand Down
2 changes: 1 addition & 1 deletion crates/wasi/src/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ where
{
fn get_stdin(&mut self) -> Result<Resource<streams::InputStream>, anyhow::Error> {
let stream = self.ctx().stdin.stream();
Ok(self.table().push(streams::InputStream::Host(stream))?)
Ok(self.table().push(stream)?)
}
}

Expand Down
16 changes: 4 additions & 12 deletions crates/wasi/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::filesystem::FileInputStream;
use crate::poll::Subscribe;
use anyhow::Result;
use bytes::Bytes;
Expand Down Expand Up @@ -254,20 +253,13 @@ impl Subscribe for Box<dyn HostOutputStream> {
}
}

pub enum InputStream {
Host(Box<dyn HostInputStream>),
File(FileInputStream),
}

#[async_trait::async_trait]
impl Subscribe for InputStream {
impl Subscribe for Box<dyn HostInputStream> {
async fn ready(&mut self) {
match self {
InputStream::Host(stream) => stream.ready().await,
// Files are always ready
InputStream::File(_) => {}
}
(**self).ready().await
}
}

pub type InputStream = Box<dyn HostInputStream>;

pub type OutputStream = Box<dyn HostOutputStream>;
5 changes: 2 additions & 3 deletions crates/wasi/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,7 @@ impl TcpSocket {
Ok(stream) => {
let stream = Arc::new(stream);
self.tcp_state = TcpState::Connected(stream.clone());
let input: InputStream =
InputStream::Host(Box::new(TcpReadStream::new(stream.clone())));
let input: InputStream = Box::new(TcpReadStream::new(stream.clone()));
let output: OutputStream = Box::new(TcpWriteStream::new(stream));
Ok((input, output))
}
Expand Down Expand Up @@ -428,7 +427,7 @@ impl TcpSocket {

let client = Arc::new(client);

let input: InputStream = InputStream::Host(Box::new(TcpReadStream::new(client.clone())));
let input: InputStream = Box::new(TcpReadStream::new(client.clone()));
let output: OutputStream = Box::new(TcpWriteStream::new(client.clone()));
let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), self.family)?;

Expand Down

0 comments on commit 33d38f9

Please sign in to comment.