diff --git a/cogment_lab/humans/actor.py b/cogment_lab/humans/actor.py index 4bb1649..a6c8c11 100644 --- a/cogment_lab/humans/actor.py +++ b/cogment_lab/humans/actor.py @@ -47,6 +47,19 @@ def image_to_msg(img: np.ndarray | None) -> str | None: return f"data:image/png;base64,{base64_encoded_result_str}" +def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]: + if isinstance(obs, np.ndarray): + obs = obs.tolist() + elif isinstance(obs, dict): + # Recursively convert all numpy arrays to lists + obs = {k: obs_to_msg(v) for k, v in obs.items()} + elif isinstance(obs, np.integer): + obs = int(obs) + elif isinstance(obs, np.floating): + obs = float(obs) + return obs + + def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int: if isinstance(action_map, list): action_map = {action: i for i, action in enumerate(action_map)} @@ -111,17 +124,25 @@ async def websocket_endpoint(websocket: WebSocket): while True: try: logging.info("Waiting for frame") - frame: np.ndarray = await recv_queue.get() + out: tuple = await recv_queue.get() + obs, frame = out if not isinstance(frame, np.ndarray): logging.warning(f"Got frame of type {type(frame)}") continue logging.info(f"Got frame with shape {frame.shape}") - msg = image_to_msg(frame) - if msg is not None: - await websocket.send_text(msg) + image_data = image_to_msg(frame) + obs_data = obs_to_msg(obs) + + msg = {"observation": obs_data, "image": image_data} + + if image_data is not None: + # await websocket.send_text(image_data) + await websocket.send_json(msg) try: - action_data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0 / fps) + action_data = await asyncio.wait_for( + websocket.receive_text(), timeout=1.0 / fps if fps > 0 else None + ) last_action_data = action_data logging.info(f"Got action {action_data}, updated {last_action_data=}") except asyncio.TimeoutError: @@ -161,7 +182,7 @@ async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) # if rendered_frame is not None # else "no frame" # ) - await self.send_queue.put(rendered_frame) + await self.send_queue.put((observation, rendered_frame)) action = await self.recv_queue.get() return action diff --git a/cogment_lab/humans/static/index.html b/cogment_lab/humans/static/index.html index bc8c60a..a437a15 100644 --- a/cogment_lab/humans/static/index.html +++ b/cogment_lab/humans/static/index.html @@ -17,10 +17,11 @@