diff --git a/crates/bh_agent_client/src/bindings.rs b/crates/bh_agent_client/src/bindings.rs index c18b1bf..987def0 100644 --- a/crates/bh_agent_client/src/bindings.rs +++ b/crates/bh_agent_client/src/bindings.rs @@ -383,6 +383,19 @@ impl BhAgentClient { ) } + fn file_set_blocking(&self, env_id: EnvironmentId, fd: FileId, blocking: bool) -> PyResult<()> { + debug!( + "Setting file blocking for environment {}, fd {}, blocking {}", + env_id, fd, blocking + ); + + run_in_runtime( + self, + self.client + .file_set_blocking(context::current(), env_id, fd, blocking), + ) + } + fn chown( &self, env_id: EnvironmentId, diff --git a/crates/bh_agent_common/src/service.rs b/crates/bh_agent_common/src/service.rs index 73e32f3..b2b09f1 100644 --- a/crates/bh_agent_common/src/service.rs +++ b/crates/bh_agent_common/src/service.rs @@ -87,6 +87,12 @@ pub trait BhAgentService { async fn file_write(env_id: EnvironmentId, fd: FileId, data: Vec) -> Result<(), AgentError>; + async fn file_set_blocking( + env_id: EnvironmentId, + fd: FileId, + blocking: bool, + ) -> Result<(), AgentError>; + async fn chown( env_id: EnvironmentId, path: String, diff --git a/crates/bh_agent_server/src/server.rs b/crates/bh_agent_server/src/server.rs index bbd85b6..d8cda39 100644 --- a/crates/bh_agent_server/src/server.rs +++ b/crates/bh_agent_server/src/server.rs @@ -14,7 +14,7 @@ use bh_agent_common::{AgentError::*, UserId}; use crate::state::BhAgentState; #[cfg(target_family = "unix")] -use crate::util::{chmod, chown, stat}; +use crate::util::{chmod, chown, set_blocking, stat}; use crate::util::{read_generic, read_lines}; macro_rules! check_env_id { @@ -285,6 +285,27 @@ impl BhAgentService for BhAgentServer { ) } + type FileSetBlockingFut = Ready>; + fn file_set_blocking( + self, + _: Context, + env_id: EnvironmentId, + fd: FileId, + blocking: bool, + ) -> Self::FileSetBlockingFut { + check_env_id!(env_id); + + #[cfg(target_family = "unix")] + return ready( + self.state + .do_mut_operation(&fd, |file| set_blocking(file, blocking)) + .map(|_| ()), + ); + + #[cfg(not(target_family = "unix"))] + return ready(Err(AgentError::UnsupportedPlatform)); + } + type ChownFut = Ready>; fn chown( self, diff --git a/crates/bh_agent_server/src/util/mod.rs b/crates/bh_agent_server/src/util/mod.rs index 900681f..0a8af1d 100644 --- a/crates/bh_agent_server/src/util/mod.rs +++ b/crates/bh_agent_server/src/util/mod.rs @@ -1,9 +1,13 @@ mod read_chars; mod read_lines; #[cfg(target_family = "unix")] +mod set_blocking; +#[cfg(target_family = "unix")] mod unix_functions; pub use read_chars::*; pub use read_lines::read_lines; #[cfg(target_family = "unix")] +pub use set_blocking::set_blocking; +#[cfg(target_family = "unix")] pub use unix_functions::{chmod, chown, stat}; diff --git a/crates/bh_agent_server/src/util/set_blocking.rs b/crates/bh_agent_server/src/util/set_blocking.rs new file mode 100644 index 0000000..f47d0fc --- /dev/null +++ b/crates/bh_agent_server/src/util/set_blocking.rs @@ -0,0 +1,24 @@ +use nix::libc::{fcntl, F_GETFL, F_SETFL, O_NONBLOCK}; +use std::io; +use std::os::unix::io::AsRawFd; + +// TODO: This is using libc directly, but nix has a wrapper for this. We should use that instead. +pub fn set_blocking(fd: &T, blocking: bool) -> io::Result<()> { + let raw_fd = fd.as_raw_fd(); + let flags = unsafe { fcntl(raw_fd, F_GETFL, 0) }; + if flags < 0 { + return Err(io::Error::last_os_error()); + } + + let flags = if blocking { + flags & !O_NONBLOCK + } else { + flags | O_NONBLOCK + }; + let res = unsafe { fcntl(raw_fd, F_SETFL, flags) }; + if res != 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} diff --git a/python/bh_agent_client.pyi b/python/bh_agent_client.pyi index 8f5bb04..593ac58 100644 --- a/python/bh_agent_client.pyi +++ b/python/bh_agent_client.pyi @@ -43,6 +43,7 @@ class BhAgentClient: def file_tell(self, env_id: int, fd: int) -> int: ... def file_is_writable(self, env_id: int, fd: int) -> bool: ... def file_write(self, env_id: int, fd: int, data: bytes) -> int: ... + def file_set_blocking(self, env_id: int, fd: int, blocking: bool) -> None: ... def chown(self, env_id: int, path: str, user: str, group: str) -> None: ... def chmod(self, env_id: int, path: str, mode: int) -> None: ... def stat(self, env_id: int, path: str) -> FileStat: ... diff --git a/python/binharness/agentenvironment.py b/python/binharness/agentenvironment.py index c716023..cb114dc 100644 --- a/python/binharness/agentenvironment.py +++ b/python/binharness/agentenvironment.py @@ -91,6 +91,10 @@ def writelines(self: AgentIO, lines: list[bytes]) -> None: """Write lines to the file.""" self._client.file_write(self._environment_id, self._fd, b"\n".join(lines)) + def set_blocking(self: AgentIO, blocking: bool) -> None: # noqa: FBT001 + """Set the file to non-blocking mode.""" + self._client.file_set_blocking(self._environment_id, self._fd, blocking) + class AgentProcess(Process): """A process running in an agent environment.""" diff --git a/python/binharness/localenvironment.py b/python/binharness/localenvironment.py index 988034f..8d013e8 100644 --- a/python/binharness/localenvironment.py +++ b/python/binharness/localenvironment.py @@ -2,19 +2,93 @@ from __future__ import annotations +import fcntl +import os import shutil import subprocess import tempfile +import typing from pathlib import Path -from typing import TYPE_CHECKING, AnyStr, Sequence +from typing import AnyStr, Sequence from binharness.types.environment import Environment +from binharness.types.io import IO from binharness.types.process import Process from binharness.types.stat import FileStat from binharness.util import normalize_args -if TYPE_CHECKING: - from binharness.types.io import IO + +class LocalIO(IO[AnyStr]): + """A file-like object for the local environment.""" + + inner: typing.IO[AnyStr] + + def __init__(self: LocalIO[AnyStr], inner: typing.IO[AnyStr]) -> None: + """Create a LocalIO.""" + self.inner = inner + + def close(self: LocalIO[AnyStr]) -> None: + """Close the file.""" + return self.inner.close() + + @property + def closed(self: LocalIO[AnyStr]) -> bool: + """Whether the file is closed.""" + return self.inner.closed + + def flush(self: LocalIO[AnyStr]) -> None: + """Flush the file.""" + return self.inner.flush() + + def read(self: LocalIO[AnyStr], n: int = -1) -> AnyStr: + """Read n bytes from the file.""" + return self.inner.read(n) + + def readable(self: LocalIO[AnyStr]) -> bool: + """Whether the file is readable.""" + return self.inner.readable() + + def readline(self: LocalIO[AnyStr], limit: int = -1) -> AnyStr: + """Read a line from the file.""" + return self.inner.readline(limit) + + def readlines(self: LocalIO[AnyStr], hint: int = -1) -> list[AnyStr]: + """Read lines from the file.""" + return self.inner.readlines(hint) + + def seek(self: LocalIO[AnyStr], offset: int, whence: int = 0) -> int | None: + """Seek to a position in the file.""" + return self.inner.seek(offset, whence) + + def seekable(self: LocalIO[AnyStr]) -> bool: + """Whether the file is seekable.""" + return self.inner.seekable() + + def tell(self: LocalIO[AnyStr]) -> int: + """Get the current position in the file.""" + return self.inner.tell() + + def writable(self: LocalIO[AnyStr]) -> bool: + """Whether the file is writable.""" + return self.inner.writable() + + def write(self: LocalIO[AnyStr], s: AnyStr) -> int | None: + """Write to the file.""" + return self.inner.write(s) + + def writelines(self: LocalIO[AnyStr], lines: list[AnyStr]) -> None: + """Write lines to the file.""" + self.inner.writelines(lines) + + def set_blocking(self: LocalIO[AnyStr], blocking: bool) -> None: # noqa: FBT001 + """Set the file to non-blocking mode.""" + fd = self.inner.fileno() + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + if blocking: + flags &= ~os.O_NONBLOCK + else: + flags |= os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) class LocalEnvironment(Environment): @@ -70,7 +144,7 @@ def get_tempdir(self: LocalEnvironment) -> Path: def open_file(self: Environment, path: Path, mode: str) -> IO[AnyStr]: """Open a file in the environment. Follows the same semantics as `open`.""" - return Path.open(path, mode) + return LocalIO(Path.open(path, mode)) def chown(self: Environment, path: Path, user: str, group: str) -> None: """Change the owner of a file.""" @@ -112,17 +186,23 @@ def __init__( @property def stdin(self: LocalProcess) -> IO[bytes] | None: """Get the standard input stream of the process.""" - return self.popen.stdin + if self.popen.stdin is not None: + return LocalIO(self.popen.stdin) + return None @property def stdout(self: LocalProcess) -> IO[bytes] | None: """Get the standard output stream of the process.""" - return self.popen.stdout + if self.popen.stdout is not None: + return LocalIO(self.popen.stdout) + return None @property def stderr(self: LocalProcess) -> IO[bytes] | None: """Get the standard error stream of the process.""" - return self.popen.stderr + if self.popen.stderr is not None: + return LocalIO(self.popen.stderr) + return None @property def returncode(self: LocalProcess) -> int | None: diff --git a/python/binharness/types/io.py b/python/binharness/types/io.py index 8cf8689..1dedc4e 100644 --- a/python/binharness/types/io.py +++ b/python/binharness/types/io.py @@ -64,3 +64,7 @@ def __exit__( """Exit the runtime context and close the file if it's open.""" if not self.closed: self.close() + + def set_blocking(self: IO[AnyStr], blocking: bool) -> None: # noqa: FBT001 + """Set the file to blocking or non-blocking mode.""" + raise NotImplementedError