-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: A tutorial on the frozen lake environment, which involves transforming observation indices to one-hot representations. Reviewed By: rodrigodesalvobraz Differential Revision: D53691908 fbshipit-source-id: f2215ac499a5ef58a55c1fea946d9061c1fb64be
- Loading branch information
1 parent
108fb7d
commit 25fdc7c
Showing
1 changed file
with
238 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## FrozenLake-v1\n", | ||
"This example shows how to use DQN to solve the FrozenLake-v1 environment from gymasium. This environment has observations as indices (tabular observation). On the other hand, Pearl assumes that states are represented as vectors. In what follows, we show how to use Pearl's OneHotObservationsFromDiscrete wrapper to convert observations to their one-hot representations." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "8NNfwWXGvn_o", | ||
"output": { | ||
"id": 383783884102102, | ||
"loadingStatus": "loaded" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Installation\n", | ||
"If you haven't installed Pearl, please make sure you install Pearl with the following cell. Otherwise, you can skip the cell below." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "1uLHbYlegKX-" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip uninstall Pearl -y\n", | ||
"%rm -rf Pearl\n", | ||
"!git clone https://github.com/facebookresearch/Pearl.git\n", | ||
"%cd Pearl\n", | ||
"%pip install .\n", | ||
"%cd .." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Import Modules" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 46, | ||
"metadata": { | ||
"id": "vcb70ZC_h3OA" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n", | ||
"from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning\n", | ||
"from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer\n", | ||
"from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n", | ||
"from pearl.pearl_agent import PearlAgent\n", | ||
"from pearl.utils.instantiations.environments.gym_environment import GymEnvironment\n", | ||
"from pearl.utils.instantiations.environments.environments import (\n", | ||
" OneHotObservationsFromDiscrete,\n", | ||
")\n", | ||
"from pearl.utils.instantiations.spaces.discrete import DiscreteSpace\n", | ||
"import torch\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"from pearl.action_representation_modules.one_hot_action_representation_module import (\n", | ||
" OneHotActionTensorRepresentationModule,\n", | ||
")\n", | ||
"\n", | ||
"set_seed(0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Vanilla DQN " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "kulkpFAvnOQx" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"number_of_steps = 20000\n", | ||
"record_period = 400\n", | ||
"\n", | ||
"\"\"\"\n", | ||
"This test is checking if DQN will eventually solve FrozenLake-v1\n", | ||
"whose observations need to be wrapped in a one-hot representation.\n", | ||
"\"\"\"\n", | ||
"env = OneHotObservationsFromDiscrete(\n", | ||
" GymEnvironment(\n", | ||
" \"FrozenLake-v1\", is_slippery=False, map_name=\"4x4\",\n", | ||
" )\n", | ||
")\n", | ||
"\n", | ||
"action_representation_module = OneHotActionTensorRepresentationModule(\n", | ||
" max_number_actions= env.action_space.n,\n", | ||
")\n", | ||
"\n", | ||
"assert isinstance(env.action_space, DiscreteSpace)\n", | ||
"state_dim = env.observation_space.n\n", | ||
"agent = PearlAgent(\n", | ||
" policy_learner=DeepQLearning(\n", | ||
" state_dim=state_dim,\n", | ||
" action_space=env.action_space,\n", | ||
" hidden_dims=[64, 64],\n", | ||
" training_rounds=1,\n", | ||
" action_representation_module=action_representation_module\n", | ||
" ),\n", | ||
" replay_buffer=FIFOOffPolicyReplayBuffer(1000),\n", | ||
")\n", | ||
"\n", | ||
"info = online_learning(\n", | ||
" agent=agent,\n", | ||
" env=env,\n", | ||
" number_of_steps=number_of_steps,\n", | ||
" print_every_x_steps=100,\n", | ||
" record_period=record_period,\n", | ||
" learn_after_episode=False,\n", | ||
")\n", | ||
"torch.save(info[\"return\"], \"DQN-return.pt\")\n", | ||
"plt.plot(record_period * np.arange(len(info[\"return\"])), info[\"return\"], label=\"DQN\")\n", | ||
"plt.xlabel(\"steps\")\n", | ||
"plt.ylabel(\"return\")\n", | ||
"plt.legend()\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"custom": { | ||
"cells": [], | ||
"metadata": { | ||
"custom": { | ||
"cells": [], | ||
"metadata": { | ||
"custom": { | ||
"cells": [], | ||
"metadata": { | ||
"custom": { | ||
"cells": [], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"gpuType": "T4", | ||
"include_colab_link": true, | ||
"provenance": [] | ||
}, | ||
"fileHeader": "", | ||
"fileUid": "4316417e-7688-45f2-a94f-24148bfc425e", | ||
"isAdHoc": false, | ||
"kernelspec": { | ||
"display_name": "pearl (local)", | ||
"language": "python", | ||
"name": "pearl_local" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
}, | ||
"fileHeader": "", | ||
"fileUid": "1158a851-91bb-437e-a391-aba92448f600", | ||
"indentAmount": 2, | ||
"isAdHoc": false, | ||
"language_info": { | ||
"name": "plaintext" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
}, | ||
"fileHeader": "", | ||
"fileUid": "ddf9fa29-09d7-404d-bc1b-62a580952524", | ||
"indentAmount": 2, | ||
"isAdHoc": false, | ||
"language_info": { | ||
"name": "plaintext" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
}, | ||
"fileHeader": "", | ||
"fileUid": "e751f6fa-be9e-4f88-9fef-36812551b013", | ||
"indentAmount": 2, | ||
"isAdHoc": false, | ||
"kernelspec": { | ||
"display_name": "pearl2", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
}, | ||
"indentAmount": 2 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |