From 1d147e3e70d78ab974db79f7056ef94aea854668 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Sat, 2 Mar 2024 00:34:03 +0100 Subject: [PATCH] Update the web UI (#18) Environment observations are now sent to the web UI. Web UI automatically detects ws:// or wss:// (to work locally and over the web) --- cogment_lab/humans/actor.py | 33 +++++- cogment_lab/humans/static/index.html | 9 +- tutorials/2-basic-gym.ipynb | 169 ++++++++++++++------------- 3 files changed, 120 insertions(+), 91 deletions(-) diff --git a/cogment_lab/humans/actor.py b/cogment_lab/humans/actor.py index 4bb1649..a6c8c11 100644 --- a/cogment_lab/humans/actor.py +++ b/cogment_lab/humans/actor.py @@ -47,6 +47,19 @@ def image_to_msg(img: np.ndarray | None) -> str | None: return f"data:image/png;base64,{base64_encoded_result_str}" +def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]: + if isinstance(obs, np.ndarray): + obs = obs.tolist() + elif isinstance(obs, dict): + # Recursively convert all numpy arrays to lists + obs = {k: obs_to_msg(v) for k, v in obs.items()} + elif isinstance(obs, np.integer): + obs = int(obs) + elif isinstance(obs, np.floating): + obs = float(obs) + return obs + + def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int: if isinstance(action_map, list): action_map = {action: i for i, action in enumerate(action_map)} @@ -111,17 +124,25 @@ async def websocket_endpoint(websocket: WebSocket): while True: try: logging.info("Waiting for frame") - frame: np.ndarray = await recv_queue.get() + out: tuple = await recv_queue.get() + obs, frame = out if not isinstance(frame, np.ndarray): logging.warning(f"Got frame of type {type(frame)}") continue logging.info(f"Got frame with shape {frame.shape}") - msg = image_to_msg(frame) - if msg is not None: - await websocket.send_text(msg) + image_data = image_to_msg(frame) + obs_data = obs_to_msg(obs) + + msg = {"observation": obs_data, "image": image_data} + + if image_data is not None: + # await websocket.send_text(image_data) + await websocket.send_json(msg) try: - action_data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0 / fps) + action_data = await asyncio.wait_for( + websocket.receive_text(), timeout=1.0 / fps if fps > 0 else None + ) last_action_data = action_data logging.info(f"Got action {action_data}, updated {last_action_data=}") except asyncio.TimeoutError: @@ -161,7 +182,7 @@ async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) # if rendered_frame is not None # else "no frame" # ) - await self.send_queue.put(rendered_frame) + await self.send_queue.put((observation, rendered_frame)) action = await self.recv_queue.get() return action diff --git a/cogment_lab/humans/static/index.html b/cogment_lab/humans/static/index.html index bc8c60a..6f9302e 100644 --- a/cogment_lab/humans/static/index.html +++ b/cogment_lab/humans/static/index.html @@ -17,10 +17,13 @@

RL Env Interface

document.getElementById('start-button').addEventListener('click', function() { if (!isConnected) { - socket = new WebSocket("ws://" + location.host + "/ws"); + const wsProtocol = location.protocol === 'https:' ? 'wss://' : 'ws://'; + + socket = new WebSocket(wsProtocol + location.host + "/ws"); socket.onmessage = function(event) { - var landerImage = document.getElementById('env-render'); - landerImage.src = event.data; + let image = document.getElementById('env-render'); + let json = JSON.parse(event.data); + image.src = json['image']; }; isConnected = true; this.style.display = 'none'; // Hide the button after it's clicked diff --git a/tutorials/2-basic-gym.ipynb b/tutorials/2-basic-gym.ipynb index f1a1209..023230c 100644 --- a/tutorials/2-basic-gym.ipynb +++ b/tutorials/2-basic-gym.ipynb @@ -49,8 +49,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:01.184731Z", - "start_time": "2024-01-22T12:05:00.795408Z" + "end_time": "2024-02-29T12:54:06.130371Z", + "start_time": "2024-02-29T12:54:05.713664Z" } }, "id": "6a8c350e03c758af", @@ -75,8 +75,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:01.188249Z", - "start_time": "2024-01-22T12:05:01.185756Z" + "end_time": "2024-02-29T12:54:06.135368Z", + "start_time": "2024-02-29T12:54:06.131122Z" } }, "id": "ea3f3c96625e869", @@ -86,7 +86,7 @@ "cell_type": "markdown", "source": [ "Let's launch an environment. You can use any Gymnasium or PettingZoo environments. \n", - "In this tutorial, we focus on Gymnasium, and use the `CartPole-v1` environment." + "In this tutorial, we focus on Gymnasium, and use the `LunarLander-v2` environment." ], "metadata": { "collapsed": false @@ -110,18 +110,18 @@ "outputs": [], "source": [ "cenv = GymEnvironment(\n", - " env_id=\"CartPole-v1\", # Environment ID, as registered in Gymnasium\n", + " env_id=\"LunarLander-v2\", # Environment ID, as registered in Gymnasium\n", " render=True, # True if we want to ever render the environment; requires pygame\n", ")" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:01.191740Z", - "start_time": "2024-01-22T12:05:01.188388Z" + "end_time": "2024-02-29T12:54:06.212527Z", + "start_time": "2024-02-29T12:54:06.133954Z" } }, - "id": "2adc89a51e6606ff", + "id": "6b13837f79d5419f", "execution_count": 3 }, { @@ -132,7 +132,7 @@ "metadata": { "collapsed": false }, - "id": "b7fafd1f0180c7e3" + "id": "efdc5e4ec72c6e49" }, { "cell_type": "code", @@ -148,7 +148,7 @@ ], "source": [ "await cog.run_env(cenv, \n", - " env_name=\"cartpole\", # Unique name for the environment \n", + " env_name=\"lunar\", # Unique name for the environment \n", " port=9011, # Port through which the env communicates with Cogment; has to be free and unique\n", " log_file=\"env.log\" # File to which the environment logs are written\n", ")" @@ -156,11 +156,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:01.632Z", - "start_time": "2024-01-22T12:05:01.192493Z" + "end_time": "2024-02-29T12:54:06.653102Z", + "start_time": "2024-02-29T12:54:06.213422Z" } }, - "id": "8f41d4a78c5fd3cf", + "id": "e5e10ff78a0868dc", "execution_count": 4 }, { @@ -171,7 +171,7 @@ "metadata": { "collapsed": false }, - "id": "40541f3a144313d6" + "id": "37e96d70e7916ae8" }, { "cell_type": "code", @@ -180,8 +180,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)\n", - "Action space: Discrete(2)\n" + "Observation space: Box([-1.5 -1.5 -5. -5. -3.1415927 -5.\n", + " -0. -0. ], [1.5 1.5 5. 5. 3.1415927 5. 1.\n", + " 1. ], (8,), float32)\n", + "Action space: Discrete(4)\n" ] } ], @@ -196,11 +198,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:01.660873Z", - "start_time": "2024-01-22T12:05:01.651694Z" + "end_time": "2024-02-29T12:54:06.656505Z", + "start_time": "2024-02-29T12:54:06.652838Z" } }, - "id": "5420cccfd23809c5", + "id": "ca57199fb45e38ef", "execution_count": 5 }, { @@ -211,7 +213,7 @@ "metadata": { "collapsed": false }, - "id": "b5e7f6898ae2b5ce" + "id": "2ea8b53bb2d4b862" }, { "cell_type": "markdown", @@ -223,7 +225,7 @@ "metadata": { "collapsed": false }, - "id": "2b6504629bbb5692" + "id": "543732e1926b1460" }, { "cell_type": "code", @@ -234,11 +236,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:01.661371Z", - "start_time": "2024-01-22T12:05:01.653933Z" + "end_time": "2024-02-29T12:54:06.657051Z", + "start_time": "2024-02-29T12:54:06.655301Z" } }, - "id": "3c3e2fa439f0ed4d", + "id": "e886a553e6edfcdd", "execution_count": 6 }, { @@ -249,7 +251,7 @@ "metadata": { "collapsed": false }, - "id": "33384582e53f70ba" + "id": "26743200a23d1d0b" }, { "cell_type": "code", @@ -273,11 +275,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.173737Z", - "start_time": "2024-01-22T12:05:01.656401Z" + "end_time": "2024-02-29T12:54:07.804480Z", + "start_time": "2024-02-29T12:54:06.657829Z" } }, - "id": "d7207f005843178f", + "id": "4fdddb4dd638f5fa", "execution_count": 7 }, { @@ -288,7 +290,7 @@ "metadata": { "collapsed": false }, - "id": "9a074e24d3049bc2" + "id": "28d2903e8de7de7b" }, { "cell_type": "code", @@ -314,11 +316,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.695609Z", - "start_time": "2024-01-22T12:05:02.173262Z" + "end_time": "2024-02-29T12:54:07.854839Z", + "start_time": "2024-02-29T12:54:07.078852Z" } }, - "id": "b3af97cae58a5e09", + "id": "ec5d4e81c7c0ed18", "execution_count": 8 }, { @@ -331,14 +333,14 @@ "metadata": { "collapsed": false }, - "id": "dbd2ff3ab5966087" + "id": "1ddc1bb0895db85c" }, { "cell_type": "code", "outputs": [], "source": [ "trial_id = await cog.start_trial(\n", - " env_name=\"cartpole\", # Name of the environment to use\n", + " env_name=\"lunar\", # Name of the environment to use\n", " actor_impls={\n", " \"gym\": \"random\", # Name of the actor to use. For Gymnasium environments, the key is always \"gym\"\n", " },\n", @@ -348,11 +350,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.741119Z", - "start_time": "2024-01-22T12:05:02.696368Z" + "end_time": "2024-02-29T12:54:07.855183Z", + "start_time": "2024-02-29T12:54:07.493474Z" } }, - "id": "9091ef7c5c2c6ebb", + "id": "c2a1f766961e0b0c", "execution_count": 9 }, { @@ -363,7 +365,7 @@ "metadata": { "collapsed": false }, - "id": "b4c8c7f29f7eb888" + "id": "fd75f08e59fe2590" }, { "cell_type": "code", @@ -372,11 +374,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Observation shape: (38, 4)\n", - "Action shape: (38,)\n", - "Reward shape: (38,)\n", - "Done shape: (38,)\n", - "Next observation shape: (38, 4)\n" + "Observation shape: (65, 8)\n", + "Action shape: (65,)\n", + "Reward shape: (65,)\n", + "Done shape: (65,)\n", + "Next observation shape: (65, 8)\n" ] } ], @@ -390,11 +392,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.741622Z", - "start_time": "2024-01-22T12:05:02.737084Z" + "end_time": "2024-02-29T12:54:07.855518Z", + "start_time": "2024-02-29T12:54:07.628418Z" } }, - "id": "87743a3a193d2f5", + "id": "15c4c46964bbd015", "execution_count": 10 }, { @@ -405,7 +407,7 @@ "metadata": { "collapsed": false }, - "id": "f052de4d8eed7a1e" + "id": "9e48b20432a2dc62" }, { "cell_type": "markdown", @@ -415,7 +417,7 @@ "metadata": { "collapsed": false }, - "id": "361788887641254b" + "id": "ae5eeb412208029e" }, { "cell_type": "code", @@ -424,8 +426,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "[0 1 0 0 0 1 0 1 0 1 1 0 1 0 1 1 0 1 1 0 1 0 0 1 0 0 1 1 0 0 1 1 1 1 0 0 1\n", - " 0]\n" + "[3 3 1 2 0 1 1 0 3 1 2 2 1 1 2 1 0 0 2 3 2 3 3 0 1 3 2 3 2 3 0 1 2 1 0 2 1\n", + " 1 2 2 2 1 1 0 1 2 2 0 1 0 2 3 1 1 0 3 3 2 1 1 2 2 0 1 0]\n" ] } ], @@ -435,11 +437,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.741901Z", - "start_time": "2024-01-22T12:05:02.739467Z" + "end_time": "2024-02-29T12:54:07.855826Z", + "start_time": "2024-02-29T12:54:07.630661Z" } }, - "id": "7bac6d8ceb1ca4b9", + "id": "e449a5c8c7f1155b", "execution_count": 11 }, { @@ -450,14 +452,14 @@ "metadata": { "collapsed": false }, - "id": "fc1fd882c72eca1e" + "id": "a7e205ac71786beb" }, { "cell_type": "code", "outputs": [], "source": [ "trial_id = await cog.start_trial(\n", - " env_name=\"cartpole\",\n", + " env_name=\"lunar\",\n", " actor_impls={\n", " \"gym\": \"constant\",\n", " },\n", @@ -467,11 +469,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.757365Z", - "start_time": "2024-01-22T12:05:02.741947Z" + "end_time": "2024-02-29T12:54:07.856144Z", + "start_time": "2024-02-29T12:54:07.632968Z" } }, - "id": "cc6e1aebec673d27", + "id": "3c3dd1a86783f0aa", "execution_count": 12 }, { @@ -482,7 +484,7 @@ "metadata": { "collapsed": false }, - "id": "2b726377076beee4" + "id": "a07f23c011f0362d" }, { "cell_type": "code", @@ -491,7 +493,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "[0 0 0 0 0 0 0 0 0 0 0]\n" + "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n" ] } ], @@ -501,11 +504,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:02.775739Z", - "start_time": "2024-01-22T12:05:02.756588Z" + "end_time": "2024-02-29T12:54:07.856472Z", + "start_time": "2024-02-29T12:54:07.669279Z" } }, - "id": "777ae3b4235af5e1", + "id": "de075c1897c7ab30", "execution_count": 13 }, { @@ -518,7 +521,7 @@ "metadata": { "collapsed": false }, - "id": "8aff1e6d3a2a6649" + "id": "b73061f48f4e8f7a" }, { "cell_type": "code", @@ -537,8 +540,10 @@ " app_port=8000, # Port through which the web UI is accessible\n", " cogment_port=8999, # Port through which the web UI communicates with Cogment; has to be free and unique\n", " actions=[\n", - " \"no-op\", # If nothing is pressed, the action is 0, i.e. the index of \"no-op\" \n", - " \"ArrowRight\" # If the right arrow is pressed, the action is 1, i.e. the index of \"ArrowRight\"\n", + " \"no-op\",\n", + " \"ArrowRight\",\n", + " \"ArrowUp\",\n", + " \"ArrowLeft\"\n", " ],\n", " log_file=\"human.log\"\n", ")" @@ -546,11 +551,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:03.286519Z", - "start_time": "2024-01-22T12:05:02.758783Z" + "end_time": "2024-02-29T12:54:08.190822Z", + "start_time": "2024-02-29T12:54:07.671633Z" } }, - "id": "899fa9b018b34048", + "id": "4389246fa64ef1cd", "execution_count": 14 }, { @@ -561,14 +566,14 @@ "metadata": { "collapsed": false }, - "id": "acaa190815aa5b57" + "id": "98544818ce7647de" }, { "cell_type": "code", "outputs": [], "source": [ "trial_id = await cog.start_trial(\n", - " env_name=\"cartpole\",\n", + " env_name=\"lunar\",\n", " session_config={\"render\": True}, # Tell cogment that we want to use the renders of the environment\n", " actor_impls={\n", " \"gym\": \"web_ui\",\n", @@ -580,11 +585,11 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:10.569759Z", - "start_time": "2024-01-22T12:05:03.287684Z" + "end_time": "2024-02-29T12:54:27.987726Z", + "start_time": "2024-02-29T12:54:08.189816Z" } }, - "id": "dac6af7393d0f866", + "id": "f201c3efb8fcaec1", "execution_count": 15 }, { @@ -592,9 +597,9 @@ "source": [ "You may see that the cell above is still running. This is because Cogment is waiting for the human - you!\n", "\n", - "Open your browser, and go to `http://localhost:8000`. Click the Start button, and then press (and depress) your right arrow to try to balance the cartpole.\n", + "Open your browser, and go to `http://localhost:8000`. Click the Start button, and then use the arrow keys to try and land the lunar lander.\n", "\n", - "After it inevitably falls (or times out, if you're good), go back here and check your result below:" + "After it inevitably crashes (or lands, if you're good), go back here and check your result below:" ], "metadata": { "collapsed": false @@ -608,7 +613,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Your reward is 13.0\n" + "Your reward is 305.34600830078125\n" ] } ], @@ -618,8 +623,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:10.573865Z", - "start_time": "2024-01-22T12:05:10.570277Z" + "end_time": "2024-02-29T12:54:27.992004Z", + "start_time": "2024-02-29T12:54:27.988091Z" } }, "id": "8a9e4fc3efc77d72", @@ -646,8 +651,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-22T12:05:12.609152Z", - "start_time": "2024-01-22T12:05:10.572592Z" + "end_time": "2024-02-29T12:54:29.014194Z", + "start_time": "2024-02-29T12:54:27.990580Z" } }, "id": "4447e862eb37c7c9",