From 25fdc7cc52b2f4619e3e458e250547e503086c0d Mon Sep 17 00:00:00 2001 From: Yi Wan Date: Mon, 12 Feb 2024 20:20:47 -0800 Subject: [PATCH] frozen lake Summary: A tutorial on the frozen lake environment, which involves transforming observation indices to one-hot representations. Reviewed By: rodrigodesalvobraz Differential Revision: D53691908 fbshipit-source-id: f2215ac499a5ef58a55c1fea946d9061c1fb64be --- pearl/tutorials/frozen_lake/demo.ipynb | 238 +++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 pearl/tutorials/frozen_lake/demo.ipynb diff --git a/pearl/tutorials/frozen_lake/demo.ipynb b/pearl/tutorials/frozen_lake/demo.ipynb new file mode 100644 index 00000000..62484a8a --- /dev/null +++ b/pearl/tutorials/frozen_lake/demo.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FrozenLake-v1\n", + "This example shows how to use DQN to solve the FrozenLake-v1 environment from gymasium. This environment has observations as indices (tabular observation). On the other hand, Pearl assumes that states are represented as vectors. In what follows, we show how to use Pearl's OneHotObservationsFromDiscrete wrapper to convert observations to their one-hot representations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8NNfwWXGvn_o", + "output": { + "id": 383783884102102, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation\n", + "If you haven't installed Pearl, please make sure you install Pearl with the following cell. Otherwise, you can skip the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1uLHbYlegKX-" + }, + "outputs": [], + "source": [ + "%pip uninstall Pearl -y\n", + "%rm -rf Pearl\n", + "!git clone https://github.com/facebookresearch/Pearl.git\n", + "%cd Pearl\n", + "%pip install .\n", + "%cd .." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import Modules" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "id": "vcb70ZC_h3OA" + }, + "outputs": [], + "source": [ + "from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n", + "from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning\n", + "from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer\n", + "from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n", + "from pearl.pearl_agent import PearlAgent\n", + "from pearl.utils.instantiations.environments.gym_environment import GymEnvironment\n", + "from pearl.utils.instantiations.environments.environments import (\n", + " OneHotObservationsFromDiscrete,\n", + ")\n", + "from pearl.utils.instantiations.spaces.discrete import DiscreteSpace\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from pearl.action_representation_modules.one_hot_action_representation_module import (\n", + " OneHotActionTensorRepresentationModule,\n", + ")\n", + "\n", + "set_seed(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vanilla DQN " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kulkpFAvnOQx" + }, + "outputs": [], + "source": [ + "number_of_steps = 20000\n", + "record_period = 400\n", + "\n", + "\"\"\"\n", + "This test is checking if DQN will eventually solve FrozenLake-v1\n", + "whose observations need to be wrapped in a one-hot representation.\n", + "\"\"\"\n", + "env = OneHotObservationsFromDiscrete(\n", + " GymEnvironment(\n", + " \"FrozenLake-v1\", is_slippery=False, map_name=\"4x4\",\n", + " )\n", + ")\n", + "\n", + "action_representation_module = OneHotActionTensorRepresentationModule(\n", + " max_number_actions= env.action_space.n,\n", + ")\n", + "\n", + "assert isinstance(env.action_space, DiscreteSpace)\n", + "state_dim = env.observation_space.n\n", + "agent = PearlAgent(\n", + " policy_learner=DeepQLearning(\n", + " state_dim=state_dim,\n", + " action_space=env.action_space,\n", + " hidden_dims=[64, 64],\n", + " training_rounds=1,\n", + " action_representation_module=action_representation_module\n", + " ),\n", + " replay_buffer=FIFOOffPolicyReplayBuffer(1000),\n", + ")\n", + "\n", + "info = online_learning(\n", + " agent=agent,\n", + " env=env,\n", + " number_of_steps=number_of_steps,\n", + " print_every_x_steps=100,\n", + " record_period=record_period,\n", + " learn_after_episode=False,\n", + ")\n", + "torch.save(info[\"return\"], \"DQN-return.pt\")\n", + "plt.plot(record_period * np.arange(len(info[\"return\"])), info[\"return\"], label=\"DQN\")\n", + "plt.xlabel(\"steps\")\n", + "plt.ylabel(\"return\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "custom": { + "cells": [], + "metadata": { + "custom": { + "cells": [], + "metadata": { + "custom": { + "cells": [], + "metadata": { + "custom": { + "cells": [], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "fileHeader": "", + "fileUid": "4316417e-7688-45f2-a94f-24148bfc425e", + "isAdHoc": false, + "kernelspec": { + "display_name": "pearl (local)", + "language": "python", + "name": "pearl_local" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 + }, + "fileHeader": "", + "fileUid": "1158a851-91bb-437e-a391-aba92448f600", + "indentAmount": 2, + "isAdHoc": false, + "language_info": { + "name": "plaintext" + } + }, + "nbformat": 4, + "nbformat_minor": 2 + }, + "fileHeader": "", + "fileUid": "ddf9fa29-09d7-404d-bc1b-62a580952524", + "indentAmount": 2, + "isAdHoc": false, + "language_info": { + "name": "plaintext" + } + }, + "nbformat": 4, + "nbformat_minor": 2 + }, + "fileHeader": "", + "fileUid": "e751f6fa-be9e-4f88-9fef-36812551b013", + "indentAmount": 2, + "isAdHoc": false, + "kernelspec": { + "display_name": "pearl2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 + }, + "indentAmount": 2 + }, + "nbformat": 4, + "nbformat_minor": 2 +}