Skip to content

Commit

Permalink
Update the web UI (#18)
Browse files Browse the repository at this point in the history
Environment observations are now sent to the web UI.

Web UI automatically detects ws:// or wss:// (to work locally and over
the web)
  • Loading branch information
RedTachyon authored Mar 1, 2024
1 parent 586e628 commit 1d147e3
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 91 deletions.
33 changes: 27 additions & 6 deletions cogment_lab/humans/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ def image_to_msg(img: np.ndarray | None) -> str | None:
return f"data:image/png;base64,{base64_encoded_result_str}"


def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]:
if isinstance(obs, np.ndarray):
obs = obs.tolist()
elif isinstance(obs, dict):
# Recursively convert all numpy arrays to lists
obs = {k: obs_to_msg(v) for k, v in obs.items()}
elif isinstance(obs, np.integer):
obs = int(obs)
elif isinstance(obs, np.floating):
obs = float(obs)
return obs


def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int:
if isinstance(action_map, list):
action_map = {action: i for i, action in enumerate(action_map)}
Expand Down Expand Up @@ -111,17 +124,25 @@ async def websocket_endpoint(websocket: WebSocket):
while True:
try:
logging.info("Waiting for frame")
frame: np.ndarray = await recv_queue.get()
out: tuple = await recv_queue.get()
obs, frame = out
if not isinstance(frame, np.ndarray):
logging.warning(f"Got frame of type {type(frame)}")
continue
logging.info(f"Got frame with shape {frame.shape}")
msg = image_to_msg(frame)
if msg is not None:
await websocket.send_text(msg)
image_data = image_to_msg(frame)
obs_data = obs_to_msg(obs)

msg = {"observation": obs_data, "image": image_data}

if image_data is not None:
# await websocket.send_text(image_data)
await websocket.send_json(msg)

try:
action_data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0 / fps)
action_data = await asyncio.wait_for(
websocket.receive_text(), timeout=1.0 / fps if fps > 0 else None
)
last_action_data = action_data
logging.info(f"Got action {action_data}, updated {last_action_data=}")
except asyncio.TimeoutError:
Expand Down Expand Up @@ -161,7 +182,7 @@ async def act(self, observation: Any, rendered_frame: np.ndarray | None = None)
# if rendered_frame is not None
# else "no frame"
# )
await self.send_queue.put(rendered_frame)
await self.send_queue.put((observation, rendered_frame))
action = await self.recv_queue.get()
return action

Expand Down
9 changes: 6 additions & 3 deletions cogment_lab/humans/static/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ <h1>RL Env Interface</h1>

document.getElementById('start-button').addEventListener('click', function() {
if (!isConnected) {
socket = new WebSocket("ws://" + location.host + "/ws");
const wsProtocol = location.protocol === 'https:' ? 'wss://' : 'ws://';

socket = new WebSocket(wsProtocol + location.host + "/ws");
socket.onmessage = function(event) {
var landerImage = document.getElementById('env-render');
landerImage.src = event.data;
let image = document.getElementById('env-render');
let json = JSON.parse(event.data);
image.src = json['image'];
};
isConnected = true;
this.style.display = 'none'; // Hide the button after it's clicked
Expand Down
Loading

0 comments on commit 1d147e3

Please sign in to comment.