diff --git a/cogment_lab/humans/actor.py b/cogment_lab/humans/actor.py
index 4bb1649..a6c8c11 100644
--- a/cogment_lab/humans/actor.py
+++ b/cogment_lab/humans/actor.py
@@ -47,6 +47,19 @@ def image_to_msg(img: np.ndarray | None) -> str | None:
return f"data:image/png;base64,{base64_encoded_result_str}"
+def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]:
+ if isinstance(obs, np.ndarray):
+ obs = obs.tolist()
+ elif isinstance(obs, dict):
+ # Recursively convert all numpy arrays to lists
+ obs = {k: obs_to_msg(v) for k, v in obs.items()}
+ elif isinstance(obs, np.integer):
+ obs = int(obs)
+ elif isinstance(obs, np.floating):
+ obs = float(obs)
+ return obs
+
+
def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int:
if isinstance(action_map, list):
action_map = {action: i for i, action in enumerate(action_map)}
@@ -111,17 +124,25 @@ async def websocket_endpoint(websocket: WebSocket):
while True:
try:
logging.info("Waiting for frame")
- frame: np.ndarray = await recv_queue.get()
+ out: tuple = await recv_queue.get()
+ obs, frame = out
if not isinstance(frame, np.ndarray):
logging.warning(f"Got frame of type {type(frame)}")
continue
logging.info(f"Got frame with shape {frame.shape}")
- msg = image_to_msg(frame)
- if msg is not None:
- await websocket.send_text(msg)
+ image_data = image_to_msg(frame)
+ obs_data = obs_to_msg(obs)
+
+ msg = {"observation": obs_data, "image": image_data}
+
+ if image_data is not None:
+ # await websocket.send_text(image_data)
+ await websocket.send_json(msg)
try:
- action_data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0 / fps)
+ action_data = await asyncio.wait_for(
+ websocket.receive_text(), timeout=1.0 / fps if fps > 0 else None
+ )
last_action_data = action_data
logging.info(f"Got action {action_data}, updated {last_action_data=}")
except asyncio.TimeoutError:
@@ -161,7 +182,7 @@ async def act(self, observation: Any, rendered_frame: np.ndarray | None = None)
# if rendered_frame is not None
# else "no frame"
# )
- await self.send_queue.put(rendered_frame)
+ await self.send_queue.put((observation, rendered_frame))
action = await self.recv_queue.get()
return action
diff --git a/cogment_lab/humans/static/index.html b/cogment_lab/humans/static/index.html
index bc8c60a..6f9302e 100644
--- a/cogment_lab/humans/static/index.html
+++ b/cogment_lab/humans/static/index.html
@@ -17,10 +17,13 @@
RL Env Interface
document.getElementById('start-button').addEventListener('click', function() {
if (!isConnected) {
- socket = new WebSocket("ws://" + location.host + "/ws");
+ const wsProtocol = location.protocol === 'https:' ? 'wss://' : 'ws://';
+
+ socket = new WebSocket(wsProtocol + location.host + "/ws");
socket.onmessage = function(event) {
- var landerImage = document.getElementById('env-render');
- landerImage.src = event.data;
+ let image = document.getElementById('env-render');
+ let json = JSON.parse(event.data);
+ image.src = json['image'];
};
isConnected = true;
this.style.display = 'none'; // Hide the button after it's clicked
diff --git a/tutorials/2-basic-gym.ipynb b/tutorials/2-basic-gym.ipynb
index f1a1209..023230c 100644
--- a/tutorials/2-basic-gym.ipynb
+++ b/tutorials/2-basic-gym.ipynb
@@ -49,8 +49,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:01.184731Z",
- "start_time": "2024-01-22T12:05:00.795408Z"
+ "end_time": "2024-02-29T12:54:06.130371Z",
+ "start_time": "2024-02-29T12:54:05.713664Z"
}
},
"id": "6a8c350e03c758af",
@@ -75,8 +75,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:01.188249Z",
- "start_time": "2024-01-22T12:05:01.185756Z"
+ "end_time": "2024-02-29T12:54:06.135368Z",
+ "start_time": "2024-02-29T12:54:06.131122Z"
}
},
"id": "ea3f3c96625e869",
@@ -86,7 +86,7 @@
"cell_type": "markdown",
"source": [
"Let's launch an environment. You can use any Gymnasium or PettingZoo environments. \n",
- "In this tutorial, we focus on Gymnasium, and use the `CartPole-v1` environment."
+ "In this tutorial, we focus on Gymnasium, and use the `LunarLander-v2` environment."
],
"metadata": {
"collapsed": false
@@ -110,18 +110,18 @@
"outputs": [],
"source": [
"cenv = GymEnvironment(\n",
- " env_id=\"CartPole-v1\", # Environment ID, as registered in Gymnasium\n",
+ " env_id=\"LunarLander-v2\", # Environment ID, as registered in Gymnasium\n",
" render=True, # True if we want to ever render the environment; requires pygame\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:01.191740Z",
- "start_time": "2024-01-22T12:05:01.188388Z"
+ "end_time": "2024-02-29T12:54:06.212527Z",
+ "start_time": "2024-02-29T12:54:06.133954Z"
}
},
- "id": "2adc89a51e6606ff",
+ "id": "6b13837f79d5419f",
"execution_count": 3
},
{
@@ -132,7 +132,7 @@
"metadata": {
"collapsed": false
},
- "id": "b7fafd1f0180c7e3"
+ "id": "efdc5e4ec72c6e49"
},
{
"cell_type": "code",
@@ -148,7 +148,7 @@
],
"source": [
"await cog.run_env(cenv, \n",
- " env_name=\"cartpole\", # Unique name for the environment \n",
+ " env_name=\"lunar\", # Unique name for the environment \n",
" port=9011, # Port through which the env communicates with Cogment; has to be free and unique\n",
" log_file=\"env.log\" # File to which the environment logs are written\n",
")"
@@ -156,11 +156,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:01.632Z",
- "start_time": "2024-01-22T12:05:01.192493Z"
+ "end_time": "2024-02-29T12:54:06.653102Z",
+ "start_time": "2024-02-29T12:54:06.213422Z"
}
},
- "id": "8f41d4a78c5fd3cf",
+ "id": "e5e10ff78a0868dc",
"execution_count": 4
},
{
@@ -171,7 +171,7 @@
"metadata": {
"collapsed": false
},
- "id": "40541f3a144313d6"
+ "id": "37e96d70e7916ae8"
},
{
"cell_type": "code",
@@ -180,8 +180,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)\n",
- "Action space: Discrete(2)\n"
+ "Observation space: Box([-1.5 -1.5 -5. -5. -3.1415927 -5.\n",
+ " -0. -0. ], [1.5 1.5 5. 5. 3.1415927 5. 1.\n",
+ " 1. ], (8,), float32)\n",
+ "Action space: Discrete(4)\n"
]
}
],
@@ -196,11 +198,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:01.660873Z",
- "start_time": "2024-01-22T12:05:01.651694Z"
+ "end_time": "2024-02-29T12:54:06.656505Z",
+ "start_time": "2024-02-29T12:54:06.652838Z"
}
},
- "id": "5420cccfd23809c5",
+ "id": "ca57199fb45e38ef",
"execution_count": 5
},
{
@@ -211,7 +213,7 @@
"metadata": {
"collapsed": false
},
- "id": "b5e7f6898ae2b5ce"
+ "id": "2ea8b53bb2d4b862"
},
{
"cell_type": "markdown",
@@ -223,7 +225,7 @@
"metadata": {
"collapsed": false
},
- "id": "2b6504629bbb5692"
+ "id": "543732e1926b1460"
},
{
"cell_type": "code",
@@ -234,11 +236,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:01.661371Z",
- "start_time": "2024-01-22T12:05:01.653933Z"
+ "end_time": "2024-02-29T12:54:06.657051Z",
+ "start_time": "2024-02-29T12:54:06.655301Z"
}
},
- "id": "3c3e2fa439f0ed4d",
+ "id": "e886a553e6edfcdd",
"execution_count": 6
},
{
@@ -249,7 +251,7 @@
"metadata": {
"collapsed": false
},
- "id": "33384582e53f70ba"
+ "id": "26743200a23d1d0b"
},
{
"cell_type": "code",
@@ -273,11 +275,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.173737Z",
- "start_time": "2024-01-22T12:05:01.656401Z"
+ "end_time": "2024-02-29T12:54:07.804480Z",
+ "start_time": "2024-02-29T12:54:06.657829Z"
}
},
- "id": "d7207f005843178f",
+ "id": "4fdddb4dd638f5fa",
"execution_count": 7
},
{
@@ -288,7 +290,7 @@
"metadata": {
"collapsed": false
},
- "id": "9a074e24d3049bc2"
+ "id": "28d2903e8de7de7b"
},
{
"cell_type": "code",
@@ -314,11 +316,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.695609Z",
- "start_time": "2024-01-22T12:05:02.173262Z"
+ "end_time": "2024-02-29T12:54:07.854839Z",
+ "start_time": "2024-02-29T12:54:07.078852Z"
}
},
- "id": "b3af97cae58a5e09",
+ "id": "ec5d4e81c7c0ed18",
"execution_count": 8
},
{
@@ -331,14 +333,14 @@
"metadata": {
"collapsed": false
},
- "id": "dbd2ff3ab5966087"
+ "id": "1ddc1bb0895db85c"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"trial_id = await cog.start_trial(\n",
- " env_name=\"cartpole\", # Name of the environment to use\n",
+ " env_name=\"lunar\", # Name of the environment to use\n",
" actor_impls={\n",
" \"gym\": \"random\", # Name of the actor to use. For Gymnasium environments, the key is always \"gym\"\n",
" },\n",
@@ -348,11 +350,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.741119Z",
- "start_time": "2024-01-22T12:05:02.696368Z"
+ "end_time": "2024-02-29T12:54:07.855183Z",
+ "start_time": "2024-02-29T12:54:07.493474Z"
}
},
- "id": "9091ef7c5c2c6ebb",
+ "id": "c2a1f766961e0b0c",
"execution_count": 9
},
{
@@ -363,7 +365,7 @@
"metadata": {
"collapsed": false
},
- "id": "b4c8c7f29f7eb888"
+ "id": "fd75f08e59fe2590"
},
{
"cell_type": "code",
@@ -372,11 +374,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Observation shape: (38, 4)\n",
- "Action shape: (38,)\n",
- "Reward shape: (38,)\n",
- "Done shape: (38,)\n",
- "Next observation shape: (38, 4)\n"
+ "Observation shape: (65, 8)\n",
+ "Action shape: (65,)\n",
+ "Reward shape: (65,)\n",
+ "Done shape: (65,)\n",
+ "Next observation shape: (65, 8)\n"
]
}
],
@@ -390,11 +392,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.741622Z",
- "start_time": "2024-01-22T12:05:02.737084Z"
+ "end_time": "2024-02-29T12:54:07.855518Z",
+ "start_time": "2024-02-29T12:54:07.628418Z"
}
},
- "id": "87743a3a193d2f5",
+ "id": "15c4c46964bbd015",
"execution_count": 10
},
{
@@ -405,7 +407,7 @@
"metadata": {
"collapsed": false
},
- "id": "f052de4d8eed7a1e"
+ "id": "9e48b20432a2dc62"
},
{
"cell_type": "markdown",
@@ -415,7 +417,7 @@
"metadata": {
"collapsed": false
},
- "id": "361788887641254b"
+ "id": "ae5eeb412208029e"
},
{
"cell_type": "code",
@@ -424,8 +426,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[0 1 0 0 0 1 0 1 0 1 1 0 1 0 1 1 0 1 1 0 1 0 0 1 0 0 1 1 0 0 1 1 1 1 0 0 1\n",
- " 0]\n"
+ "[3 3 1 2 0 1 1 0 3 1 2 2 1 1 2 1 0 0 2 3 2 3 3 0 1 3 2 3 2 3 0 1 2 1 0 2 1\n",
+ " 1 2 2 2 1 1 0 1 2 2 0 1 0 2 3 1 1 0 3 3 2 1 1 2 2 0 1 0]\n"
]
}
],
@@ -435,11 +437,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.741901Z",
- "start_time": "2024-01-22T12:05:02.739467Z"
+ "end_time": "2024-02-29T12:54:07.855826Z",
+ "start_time": "2024-02-29T12:54:07.630661Z"
}
},
- "id": "7bac6d8ceb1ca4b9",
+ "id": "e449a5c8c7f1155b",
"execution_count": 11
},
{
@@ -450,14 +452,14 @@
"metadata": {
"collapsed": false
},
- "id": "fc1fd882c72eca1e"
+ "id": "a7e205ac71786beb"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"trial_id = await cog.start_trial(\n",
- " env_name=\"cartpole\",\n",
+ " env_name=\"lunar\",\n",
" actor_impls={\n",
" \"gym\": \"constant\",\n",
" },\n",
@@ -467,11 +469,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.757365Z",
- "start_time": "2024-01-22T12:05:02.741947Z"
+ "end_time": "2024-02-29T12:54:07.856144Z",
+ "start_time": "2024-02-29T12:54:07.632968Z"
}
},
- "id": "cc6e1aebec673d27",
+ "id": "3c3dd1a86783f0aa",
"execution_count": 12
},
{
@@ -482,7 +484,7 @@
"metadata": {
"collapsed": false
},
- "id": "2b726377076beee4"
+ "id": "a07f23c011f0362d"
},
{
"cell_type": "code",
@@ -491,7 +493,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[0 0 0 0 0 0 0 0 0 0 0]\n"
+ "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
]
}
],
@@ -501,11 +504,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:02.775739Z",
- "start_time": "2024-01-22T12:05:02.756588Z"
+ "end_time": "2024-02-29T12:54:07.856472Z",
+ "start_time": "2024-02-29T12:54:07.669279Z"
}
},
- "id": "777ae3b4235af5e1",
+ "id": "de075c1897c7ab30",
"execution_count": 13
},
{
@@ -518,7 +521,7 @@
"metadata": {
"collapsed": false
},
- "id": "8aff1e6d3a2a6649"
+ "id": "b73061f48f4e8f7a"
},
{
"cell_type": "code",
@@ -537,8 +540,10 @@
" app_port=8000, # Port through which the web UI is accessible\n",
" cogment_port=8999, # Port through which the web UI communicates with Cogment; has to be free and unique\n",
" actions=[\n",
- " \"no-op\", # If nothing is pressed, the action is 0, i.e. the index of \"no-op\" \n",
- " \"ArrowRight\" # If the right arrow is pressed, the action is 1, i.e. the index of \"ArrowRight\"\n",
+ " \"no-op\",\n",
+ " \"ArrowRight\",\n",
+ " \"ArrowUp\",\n",
+ " \"ArrowLeft\"\n",
" ],\n",
" log_file=\"human.log\"\n",
")"
@@ -546,11 +551,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:03.286519Z",
- "start_time": "2024-01-22T12:05:02.758783Z"
+ "end_time": "2024-02-29T12:54:08.190822Z",
+ "start_time": "2024-02-29T12:54:07.671633Z"
}
},
- "id": "899fa9b018b34048",
+ "id": "4389246fa64ef1cd",
"execution_count": 14
},
{
@@ -561,14 +566,14 @@
"metadata": {
"collapsed": false
},
- "id": "acaa190815aa5b57"
+ "id": "98544818ce7647de"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"trial_id = await cog.start_trial(\n",
- " env_name=\"cartpole\",\n",
+ " env_name=\"lunar\",\n",
" session_config={\"render\": True}, # Tell cogment that we want to use the renders of the environment\n",
" actor_impls={\n",
" \"gym\": \"web_ui\",\n",
@@ -580,11 +585,11 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:10.569759Z",
- "start_time": "2024-01-22T12:05:03.287684Z"
+ "end_time": "2024-02-29T12:54:27.987726Z",
+ "start_time": "2024-02-29T12:54:08.189816Z"
}
},
- "id": "dac6af7393d0f866",
+ "id": "f201c3efb8fcaec1",
"execution_count": 15
},
{
@@ -592,9 +597,9 @@
"source": [
"You may see that the cell above is still running. This is because Cogment is waiting for the human - you!\n",
"\n",
- "Open your browser, and go to `http://localhost:8000`. Click the Start button, and then press (and depress) your right arrow to try to balance the cartpole.\n",
+ "Open your browser, and go to `http://localhost:8000`. Click the Start button, and then use the arrow keys to try and land the lunar lander.\n",
"\n",
- "After it inevitably falls (or times out, if you're good), go back here and check your result below:"
+ "After it inevitably crashes (or lands, if you're good), go back here and check your result below:"
],
"metadata": {
"collapsed": false
@@ -608,7 +613,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Your reward is 13.0\n"
+ "Your reward is 305.34600830078125\n"
]
}
],
@@ -618,8 +623,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:10.573865Z",
- "start_time": "2024-01-22T12:05:10.570277Z"
+ "end_time": "2024-02-29T12:54:27.992004Z",
+ "start_time": "2024-02-29T12:54:27.988091Z"
}
},
"id": "8a9e4fc3efc77d72",
@@ -646,8 +651,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "end_time": "2024-01-22T12:05:12.609152Z",
- "start_time": "2024-01-22T12:05:10.572592Z"
+ "end_time": "2024-02-29T12:54:29.014194Z",
+ "start_time": "2024-02-29T12:54:27.990580Z"
}
},
"id": "4447e862eb37c7c9",