Skip to content

Commit

Permalink
Refactor UserInput to use channels
Browse files Browse the repository at this point in the history
  • Loading branch information
bakaq committed Jan 30, 2025
1 parent 63f381e commit fbfdcce
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 27 deletions.
33 changes: 13 additions & 20 deletions src/machine/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::cell::RefCell;
use std::io::{Seek, SeekFrom, Write};
use std::rc::Rc;
use std::{borrow::Cow, io::Cursor};
use std::borrow::Cow;
use std::io::Write;
use std::sync::mpsc::{channel, Receiver, Sender};

use rand::{rngs::StdRng, SeedableRng};

Expand Down Expand Up @@ -40,14 +39,12 @@ impl StreamConfig {
///
/// This also returns a handler to the stdin do the [`Machine`](crate::Machine).
pub fn with_callbacks(stdout: Option<Callback>, stderr: Option<Callback>) -> (UserInput, Self) {
let stdin = Rc::new(RefCell::new(Cursor::new(Vec::new())));
let (sender, receiver) = channel();
(
UserInput {
inner: stdin.clone(),
},
UserInput { inner: sender },
StreamConfig {
inner: StreamConfigInner::Callbacks {
stdin,
stdin: receiver,
stdout,
stderr,
},
Expand All @@ -59,23 +56,19 @@ impl StreamConfig {
/// A handler for the stdin of the [`Machine`](crate::Machine).
#[derive(Debug)]
pub struct UserInput {
inner: Rc<RefCell<Cursor<Vec<u8>>>>,
inner: Sender<Vec<u8>>,
}

impl Write for UserInput {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut inner = self.inner.borrow_mut();
let pos = inner.position();

inner.seek(SeekFrom::End(0))?;
let result = inner.write(buf);
inner.seek(SeekFrom::Start(pos))?;

result
self.inner
.send(buf.into())
.map(|_| buf.len())
.map_err(|_| std::io::ErrorKind::BrokenPipe.into())
}

fn flush(&mut self) -> std::io::Result<()> {
self.inner.borrow_mut().flush()
Ok(())
}
}

Expand All @@ -85,7 +78,7 @@ enum StreamConfigInner {
#[default]
Memory,
Callbacks {
stdin: Rc<RefCell<Cursor<Vec<u8>>>>,
stdin: Receiver<Vec<u8>>,
stdout: Option<Callback>,
stderr: Option<Callback>,
},
Expand Down
2 changes: 1 addition & 1 deletion src/machine/lib_machine/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ fn callback_streams() {

let (mut user_input, streams) = StreamConfig::with_callbacks(
Some(Box::new(move |x| {
x.read_to_string(&mut *test_string2.borrow_mut()).unwrap();
x.read_to_string(&mut test_string2.borrow_mut()).unwrap();
})),
None,
);
Expand Down
65 changes: 59 additions & 6 deletions src/machine/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub use scryer_modular_bitfield::prelude::*;

#[cfg(feature = "http")]
use bytes::{buf::Reader as BufReader, Buf, Bytes};
use std::cell::RefCell;
use std::cmp::Ordering;
use std::error::Error;
use std::fmt;
Expand All @@ -30,7 +29,8 @@ use std::net::{Shutdown, TcpStream};
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::ptr;
use std::rc::Rc;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::TryRecvError;

#[cfg(feature = "tls")]
use native_tls::TlsStream;
Expand Down Expand Up @@ -414,13 +414,50 @@ impl Write for CallbackStream {

#[derive(Debug)]
pub struct InputChannelStream {
pub(crate) inner: Rc<RefCell<Cursor<Vec<u8>>>>,
pub(crate) inner: Cursor<Vec<u8>>,
pub eof: bool,
channel: Receiver<Vec<u8>>,
}

impl Read for InputChannelStream {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.inner.borrow_mut().read(buf)
if self.eof {
return Ok(0);
}

let to_read = buf.len();
let mut total_read = 0;

loop {
total_read += self.inner.read(&mut buf[total_read..])?;

if total_read < to_read {
// We need to get more data to read
match self.channel.try_recv() {
Ok(data) => {
// Append into self.inner
let pos = self.inner.position();
assert_eq!(pos as usize, self.inner.get_ref().len());
self.inner.write_all(&data)?;
self.inner.seek(SeekFrom::Start(pos))?;
}
Err(TryRecvError::Empty) => {
// Data is pending
break;
}
Err(TryRecvError::Disconnected) => {
// The other end of the channel was closed so we are EOF
self.eof = true;
break;
}
}
} else {
assert_eq!(total_read, to_read);
break;
}
}
Ok(total_read)
}
}

Expand Down Expand Up @@ -602,9 +639,14 @@ impl Stream {
}

#[inline]
pub fn input_channel(cursor: Rc<RefCell<Cursor<Vec<u8>>>>, arena: &mut Arena) -> Stream {
pub fn input_channel(channel: Receiver<Vec<u8>>, arena: &mut Arena) -> Stream {
let inner = Cursor::new(Vec::new());
Stream::InputChannel(arena_alloc!(
StreamLayout::new(CharReader::new(InputChannelStream { inner: cursor })),
StreamLayout::new(CharReader::new(InputChannelStream {
inner,
eof: false,
channel
})),
arena
))
}
Expand Down Expand Up @@ -1239,6 +1281,13 @@ impl Stream {
AtEndOfStream::Past
}
}
Stream::InputChannel(stream_layout) => {
if stream_layout.stream.get_ref().eof {
AtEndOfStream::At
} else {
AtEndOfStream::Not
}
}
_ => AtEndOfStream::Not,
}
}
Expand Down Expand Up @@ -1499,6 +1548,10 @@ impl Stream {
readline_stream.reset();
true
}
Stream::InputChannel(ref mut input_channel_stream) => {
input_channel_stream.stream.get_mut().inner.set_position(0);
true
}
_ => false,
}
}
Expand Down

0 comments on commit fbfdcce

Please sign in to comment.