Skip to content

Commit

Permalink
Include the observation in the information being sent to the frontend
Browse files Browse the repository at this point in the history
Use a secure websocket (wss)
  • Loading branch information
RedTachyon committed Feb 29, 2024
1 parent 586e628 commit bb5af0f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
33 changes: 27 additions & 6 deletions cogment_lab/humans/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ def image_to_msg(img: np.ndarray | None) -> str | None:
return f"data:image/png;base64,{base64_encoded_result_str}"


def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]:
if isinstance(obs, np.ndarray):
obs = obs.tolist()
elif isinstance(obs, dict):
# Recursively convert all numpy arrays to lists
obs = {k: obs_to_msg(v) for k, v in obs.items()}
elif isinstance(obs, np.integer):
obs = int(obs)
elif isinstance(obs, np.floating):
obs = float(obs)
return obs


def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int:
if isinstance(action_map, list):
action_map = {action: i for i, action in enumerate(action_map)}
Expand Down Expand Up @@ -111,17 +124,25 @@ async def websocket_endpoint(websocket: WebSocket):
while True:
try:
logging.info("Waiting for frame")
frame: np.ndarray = await recv_queue.get()
out: tuple = await recv_queue.get()
obs, frame = out
if not isinstance(frame, np.ndarray):
logging.warning(f"Got frame of type {type(frame)}")
continue
logging.info(f"Got frame with shape {frame.shape}")
msg = image_to_msg(frame)
if msg is not None:
await websocket.send_text(msg)
image_data = image_to_msg(frame)
obs_data = obs_to_msg(obs)

msg = {"observation": obs_data, "image": image_data}

if image_data is not None:
# await websocket.send_text(image_data)
await websocket.send_json(msg)

try:
action_data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0 / fps)
action_data = await asyncio.wait_for(
websocket.receive_text(), timeout=1.0 / fps if fps > 0 else None
)
last_action_data = action_data
logging.info(f"Got action {action_data}, updated {last_action_data=}")
except asyncio.TimeoutError:
Expand Down Expand Up @@ -161,7 +182,7 @@ async def act(self, observation: Any, rendered_frame: np.ndarray | None = None)
# if rendered_frame is not None
# else "no frame"
# )
await self.send_queue.put(rendered_frame)
await self.send_queue.put((observation, rendered_frame))
action = await self.recv_queue.get()
return action

Expand Down
7 changes: 4 additions & 3 deletions cogment_lab/humans/static/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ <h1>RL Env Interface</h1>

document.getElementById('start-button').addEventListener('click', function() {
if (!isConnected) {
socket = new WebSocket("ws://" + location.host + "/ws");
socket = new WebSocket("wss://" + location.host + "/ws");
socket.onmessage = function(event) {
var landerImage = document.getElementById('env-render');
landerImage.src = event.data;
let image = document.getElementById('env-render');
let json = JSON.parse(event.data);
image.src = json['image'];
};
isConnected = true;
this.style.display = 'none'; // Hide the button after it's clicked
Expand Down

0 comments on commit bb5af0f

Please sign in to comment.