diff --git a/tutorials/sequential_decision_making/Implicit_Q_learning.ipynb b/tutorials/sequential_decision_making/Implicit_Q_learning.ipynb new file mode 100644 index 00000000..145484d6 --- /dev/null +++ b/tutorials/sequential_decision_making/Implicit_Q_learning.ipynb @@ -0,0 +1,639 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Using Implicit Q learning (an offline RL algorithm) in Pearl.\n", + "\n", + "Here is a [better rendering](https://nbviewer.org/github/facebookresearch/Pearl/blob/main/tutorials/sequential_decision_making/Implicit_Q_learning.ipynb) of this notebook on [nbviewer](https://nbviewer.org/)\n", + "\n", + "- The purpose of this tutorial is to illustrate how users can use Pearl's implementation of Implicit Q-learning (IQL), an offline RL algorithm.\n", + "\n", + "- This example illustrates offline learning for continuous control using\n", + "offline data collected from Open AI Gym's HalfCheetah environment.\n" + ], + "metadata": { + "id": "PBWPlSEq_fBf" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "0Fa4tQSG_YoI" + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Pearl Installation\n", + "\n", + "If you haven't installed Pearl, please make sure you install Pearl with the following cell. Otherwise, you can skip the cell below." + ], + "metadata": { + "id": "VLL6NfNABgQv" + } + }, + { + "cell_type": "code", + "source": [ + "# Pearl installation from github. This install also includes PyTorch, Gym and Matplotlib\n", + "\n", + "%pip uninstall Pearl -y\n", + "%rm -rf Pearl\n", + "!git clone https://github.com/facebookresearch/Pearl.git\n", + "%cd Pearl\n", + "%pip install .\n", + "%cd .." + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gZyj_w4sBg7L", + "outputId": "66f23017-ff4c-4127-bc09-3259f573a1c6" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[33mWARNING: Skipping Pearl as it is not installed.\u001b[0m\u001b[33m\n", + "\u001b[0mCloning into 'Pearl'...\n", + "remote: Enumerating objects: 5615, done.\u001b[K\n", + "remote: Counting objects: 100% (1824/1824), done.\u001b[K\n", + "remote: Compressing objects: 100% (613/613), done.\u001b[K\n", + "remote: Total 5615 (delta 1388), reused 1563 (delta 1199), pack-reused 3791\u001b[K\n", + "Receiving objects: 100% (5615/5615), 53.75 MiB | 21.57 MiB/s, done.\n", + "Resolving deltas: 100% (3715/3715), done.\n", + "/content/Pearl\n", + "Processing /content/Pearl\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: gym in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (0.25.2)\n", + "Collecting gymnasium[accept-rom-license,atari,mujoco] (from Pearl==0.1.0)\n", + " Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m953.9/953.9 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (1.25.2)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (3.7.1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (1.5.3)\n", + "Collecting parameterized (from Pearl==0.1.0)\n", + " Downloading parameterized-0.9.0-py2.py3-none-any.whl (20 kB)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (2.31.0)\n", + "Collecting mujoco (from Pearl==0.1.0)\n", + " Downloading mujoco-3.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m74.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (2.2.1+cu121)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (0.17.1+cu121)\n", + "Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (from Pearl==0.1.0) (2.2.1+cu121)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gym->Pearl==0.1.0) (2.2.1)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym->Pearl==0.1.0) (0.0.8)\n", + "Requirement already satisfied: typing-extensions>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0) (4.10.0)\n", + "Collecting farama-notifications>=0.0.1 (from gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0)\n", + " Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n", + "Collecting autorom[accept-rom-license]~=0.4.2 (from gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0)\n", + " Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)\n", + "Requirement already satisfied: imageio>=2.14.1 in /usr/local/lib/python3.10/dist-packages (from gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0) (2.31.6)\n", + "Collecting shimmy[atari]<1.0,>=0.1.0 (from gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0)\n", + " Downloading Shimmy-0.2.1-py3-none-any.whl (25 kB)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from mujoco->Pearl==0.1.0) (1.4.0)\n", + "Requirement already satisfied: etils[epath] in /usr/local/lib/python3.10/dist-packages (from mujoco->Pearl==0.1.0) (1.7.0)\n", + "Collecting glfw (from mujoco->Pearl==0.1.0)\n", + " Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.8/211.8 kB\u001b[0m \u001b[31m26.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: 
pyopengl in /usr/local/lib/python3.10/dist-packages (from mujoco->Pearl==0.1.0) (3.1.7)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (1.2.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (4.50.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (1.4.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (24.0)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (9.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->Pearl==0.1.0) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->Pearl==0.1.0) (2023.4)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->Pearl==0.1.0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->Pearl==0.1.0) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->Pearl==0.1.0) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->Pearl==0.1.0) (2024.2.2)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->Pearl==0.1.0) (3.13.1)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->Pearl==0.1.0) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->Pearl==0.1.0) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->Pearl==0.1.0) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->Pearl==0.1.0) (2023.6.0)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m37.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m79.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + 
"\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-nccl-cu12==2.19.3 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m166.0/166.0 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch->Pearl==0.1.0)\n", + " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch->Pearl==0.1.0) (2.2.0)\n", + "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch->Pearl==0.1.0)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from autorom[accept-rom-license]~=0.4.2->gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0) (8.1.7)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages 
(from autorom[accept-rom-license]~=0.4.2->gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0) (4.66.2)\n", + "Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.4.2->gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0)\n", + " Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.7/434.7 kB\u001b[0m \u001b[31m44.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->Pearl==0.1.0) (1.16.0)\n", + "Collecting ale-py~=0.8.1 (from shimmy[atari]<1.0,>=0.1.0->gymnasium[accept-rom-license,atari,mujoco]->Pearl==0.1.0)\n", + " Downloading ale_py-0.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath]->mujoco->Pearl==0.1.0) (6.3.2)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath]->mujoco->Pearl==0.1.0) (3.18.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->Pearl==0.1.0) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->Pearl==0.1.0) (1.3.0)\n", + "Building wheels for collected packages: Pearl, AutoROM.accept-rom-license\n", + " Building wheel for Pearl (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for Pearl: filename=Pearl-0.1.0-py3-none-any.whl size=211430 sha256=e4d88588b2376e35794984d9e88eebfdf7cfdeb772e9b8420e41193499eb4342\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-5nxlxw_t/wheels/83/80/1d/d9211ba70ee392341daf21a07252739e0cb2af9f95439a28cd\n", + " Building wheel for AutoROM.accept-rom-license (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for AutoROM.accept-rom-license: filename=AutoROM.accept_rom_license-0.6.1-py3-none-any.whl size=446659 sha256=15077f3d72b6dbbe7e9d1afa8e2148a5cc9e20ab6348782facfddb6ccc5807b7\n", + " Stored in directory: /root/.cache/pip/wheels/6b/1b/ef/a43ff1a2f1736d5711faa1ba4c1f61be1131b8899e6a057811\n", + "Successfully built Pearl AutoROM.accept-rom-license\n", + "Installing collected packages: glfw, farama-notifications, parameterized, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, gymnasium, ale-py, shimmy, nvidia-cusparse-cu12, nvidia-cudnn-cu12, AutoROM.accept-rom-license, autorom, nvidia-cusolver-cu12, mujoco, Pearl\n", + "Successfully installed AutoROM.accept-rom-license-0.6.1 Pearl-0.1.0 ale-py-0.8.1 autorom-0.4.2 farama-notifications-0.0.4 glfw-2.7.0 gymnasium-0.29.1 mujoco-3.1.3 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.99 nvidia-nvtx-cu12-12.1.105 parameterized-0.9.0 shimmy-0.2.1\n", + "/content\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import pickle\n", + "import torch\n", + "\n", + "\n", + "from pearl.pearl_agent import PearlAgent\n", + "from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n", + "from pearl.utils.instantiations.environments.gym_environment import GymEnvironment\n", + "from pearl.neural_networks.sequential_decision_making.actor_networks import VanillaContinuousActorNetwork\n", + "from pearl.policy_learners.sequential_decision_making.implicit_q_learning import ImplicitQLearning\n", + "\n", + "from pearl.utils.functional_utils.experimentation.create_offline_data import (\n", + " get_data_collection_agent_returns,\n", + ")\n", + "\n", + "from pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation import (\n", + " get_offline_data_in_buffer,\n", + " offline_evaluation,\n", + " offline_learning,\n", + ")\n", + "\n", + "set_seed(0)" + ], + "metadata": { + "id": "YoD1X9oyBnzD" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Specify some environment details" + ], + "metadata": { + "id": "I4CEVZn9Bv3g" + } + }, + { + "cell_type": "code", + "source": [ + "experiment_seed = 100\n", + "\n", + "# We choose a continuous control environment from MuJoCo called 'Half-Cheetah'.\n", + "env_name = \"HalfCheetah-v4\"\n", + "env = GymEnvironment(env_name)\n", + "action_space = env.action_space\n", + "is_action_continuous = True" + ], + "metadata": { + "id": "tciCqt6WBumg", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "0b495d4c-5ae0-4c6f-a50e-9e989acb7b90" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Download offline data\n", + "\n", + "- We have offline data at https://github.com/facebookresearch/Pearl/tree/gh-pages/data which will be used for this tutorial.\n", + "\n", + "- This dataset was collected using the replay buffer of an experiment where soft-actor critic (SAC) algorithm was used for policy learning, along with a large entropy parameter for exploration. Users can think of it as a 'medium replay' dataset in D4RL (https://github.com/Farama-Foundation/D4RL).\n", + "\n", + "- Note: We have not integrated Pearl with D4RL as it is being deprecated and a new library called Minari is being developed. We do plan to integrate Pearl with Minari at a later time." + ], + "metadata": { + "id": "IykiuKhUCC3r" + } + }, + { + "cell_type": "code", + "source": [ + "import requests\n", + "\n", + "# Download the dataset of transition tuples from the github url and store in a local file\n", + "url = \"https://github.com/facebookresearch/Pearl/raw/gh-pages/data/offline_rl_data/HalfCheetah/offline_raw_transitions_dict_small_2.pt\"\n", + "filename = \"offline_raw_transitions_dict_small_2.pt\" # local file with the dataset of transition tuples\n", + "response = requests.get(url)\n", + "with open(filename, \"wb\") as f:\n", + " f.write(response.content)" + ], + "metadata": { + "id": "X4X4YspnDlUD" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "cwd = os.getcwd() # current working directory\n", + "data_path = cwd + '/' + filename # change this in case you have a file with the dataset of transition tuples already stored in a local path.\n", + "\n", + "# The default device where offline data replay buffer is stored is cpu; see the `get_offline_data_in_buffer` for device management\n", + "offline_data_replay_buffer = get_offline_data_in_buffer(\n", + " is_action_continuous=is_action_continuous,\n", + " url=None,\n", + " data_path=data_path, # path to local file which contains the offline dataset\n", + ")" + ], + "metadata": { + "id": "SWGnW4-0cw0D" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Set up an IQL agent\n", + "\n", + "- Note that here `env` and `action_space` are for the HalfCheetah environment as set above.\n" + ], + "metadata": { + "id": "eD0Q6NwQngxf" + } + }, + { + "cell_type": "code", + "source": [ + "offline_agent = PearlAgent(\n", + " policy_learner=ImplicitQLearning(\n", + " state_dim=env.observation_space.shape[0],\n", + " action_space=action_space,\n", + " actor_hidden_dims=[256, 256],\n", + " critic_hidden_dims=[256, 256],\n", + " value_critic_hidden_dims=[256, 256],\n", + " actor_network_type=VanillaContinuousActorNetwork,\n", + " value_critic_learning_rate=1e-3,\n", + " actor_learning_rate=3e-4,\n", + " critic_learning_rate=1e-4,\n", + " critic_soft_update_tau=0.05,\n", + " training_rounds=2,\n", + " batch_size=256,\n", + " expectile=0.75,\n", + " temperature_advantage_weighted_regression=3,\n", + " ),\n", + ")" + ], + "metadata": { + "id": "2nHPHmFenfoD" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Offline learning" + ], + "metadata": { + "id": "UJ56YQqbpQ-U" + } + }, + { + "cell_type": "code", + "source": [ + "# Number of training epochs\n", + "training_epochs = 
20000\n", + "\n", + "# Use the `offline_learning` utility function in Pearl to train an offline RL agent using offline data\n", + "offline_learning(\n", + " offline_agent=offline_agent,\n", + " data_buffer=offline_data_replay_buffer, # replay buffer created using the offline data\n", + " training_epochs=training_epochs,\n", + " seed=experiment_seed,\n", + ")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "oPe4Sux2pA_H", + "outputId": "a7143493-6f54-406f-ca6c-fc8baaba0a10" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "training epoch 0 training loss {'value_loss': 0.44493526220321655, 'actor_loss': 1.1561717987060547, 'critic_loss': 5.130976676940918}\n", + "training epoch 500 training loss {'value_loss': 0.1409292221069336, 'actor_loss': 0.5989271998405457, 'critic_loss': 10.62370491027832}\n", + "training epoch 1000 training loss {'value_loss': 0.5230052471160889, 'actor_loss': 0.9807262420654297, 'critic_loss': 24.416900634765625}\n", + "training epoch 1500 training loss {'value_loss': 0.9530619382858276, 'actor_loss': 3.278916597366333, 'critic_loss': 29.79534149169922}\n", + "training epoch 2000 training loss {'value_loss': 1.614856243133545, 'actor_loss': 6.346199035644531, 'critic_loss': 43.3765869140625}\n", + "training epoch 2500 training loss {'value_loss': 2.005828619003296, 'actor_loss': 7.089072227478027, 'critic_loss': 59.126861572265625}\n", + "training epoch 3000 training loss {'value_loss': 2.0187668800354004, 'actor_loss': 3.204172134399414, 'critic_loss': 48.85820770263672}\n", + "training epoch 3500 training loss {'value_loss': 3.2361795902252197, 'actor_loss': 1.1803092956542969, 'critic_loss': 61.56980895996094}\n", + "training epoch 4000 training loss {'value_loss': 2.8879613876342773, 'actor_loss': 6.797780990600586, 'critic_loss': 41.75172805786133}\n", + "training epoch 4500 training loss {'value_loss': 3.4644596576690674, 'actor_loss': 5.647995471954346, 'critic_loss': 64.77790832519531}\n", + "training epoch 5000 training loss {'value_loss': 3.648473024368286, 'actor_loss': 3.0125679969787598, 'critic_loss': 66.62677001953125}\n", + "training epoch 5500 training loss {'value_loss': 3.801520347595215, 'actor_loss': 6.153298377990723, 'critic_loss': 61.24081039428711}\n", + "training epoch 6000 training loss {'value_loss': 4.654277801513672, 'actor_loss': 3.098684787750244, 'critic_loss': 75.80091094970703}\n", + "training epoch 6500 training loss {'value_loss': 4.773356914520264, 'actor_loss': 5.105808258056641, 'critic_loss': 244.9698486328125}\n", + "training epoch 7000 training loss {'value_loss': 5.341713905334473, 'actor_loss': 4.7835845947265625, 'critic_loss': 126.96758270263672}\n", + "training epoch 7500 training loss {'value_loss': 4.727663516998291, 'actor_loss': 6.566059112548828, 'critic_loss': 51.88398742675781}\n", + "training epoch 8000 training loss {'value_loss': 5.532264709472656, 'actor_loss': 5.600111961364746, 'critic_loss': 72.99192810058594}\n", + "training epoch 8500 training loss {'value_loss': 5.432359218597412, 'actor_loss': 4.592930793762207, 'critic_loss': 60.98143005371094}\n", + "training epoch 9000 training loss {'value_loss': 5.523667812347412, 'actor_loss': 5.735751152038574, 'critic_loss': 433.49212646484375}\n", + "training epoch 9500 training loss {'value_loss': 5.216910362243652, 'actor_loss': 4.710633754730225, 'critic_loss': 83.78570556640625}\n", + "training epoch 10000 training loss {'value_loss': 5.5933661460876465, 
'actor_loss': 6.106831073760986, 'critic_loss': 147.11474609375}\n", + "training epoch 10500 training loss {'value_loss': 8.451723098754883, 'actor_loss': 1.621864676475525, 'critic_loss': 80.56407165527344}\n", + "training epoch 11000 training loss {'value_loss': 5.930197715759277, 'actor_loss': 6.696375846862793, 'critic_loss': 51.827354431152344}\n", + "training epoch 11500 training loss {'value_loss': 6.042293071746826, 'actor_loss': 4.567534446716309, 'critic_loss': 49.92619323730469}\n", + "training epoch 12000 training loss {'value_loss': 6.438072204589844, 'actor_loss': 6.379092693328857, 'critic_loss': 61.6492919921875}\n", + "training epoch 12500 training loss {'value_loss': 7.940831661224365, 'actor_loss': 8.604472160339355, 'critic_loss': 75.86102294921875}\n", + "training epoch 13000 training loss {'value_loss': 5.591019630432129, 'actor_loss': 5.480114936828613, 'critic_loss': 48.68461990356445}\n", + "training epoch 13500 training loss {'value_loss': 7.129138946533203, 'actor_loss': 7.606431007385254, 'critic_loss': 373.935546875}\n", + "training epoch 14000 training loss {'value_loss': 6.827956199645996, 'actor_loss': 4.753203392028809, 'critic_loss': 53.25279235839844}\n", + "training epoch 14500 training loss {'value_loss': 6.406070709228516, 'actor_loss': 4.576913833618164, 'critic_loss': 312.50872802734375}\n", + "training epoch 15000 training loss {'value_loss': 5.100269317626953, 'actor_loss': 5.159261703491211, 'critic_loss': 41.06599426269531}\n", + "training epoch 15500 training loss {'value_loss': 5.180943965911865, 'actor_loss': 3.508758306503296, 'critic_loss': 117.24136352539062}\n", + "training epoch 16000 training loss {'value_loss': 5.463344573974609, 'actor_loss': 6.216361999511719, 'critic_loss': 43.10062026977539}\n", + "training epoch 16500 training loss {'value_loss': 5.401617527008057, 'actor_loss': 5.747153282165527, 'critic_loss': 55.972713470458984}\n", + "training epoch 17000 training loss {'value_loss': 3.9423627853393555, 'actor_loss': 5.501956462860107, 'critic_loss': 37.53641891479492}\n", + "training epoch 17500 training loss {'value_loss': 5.858860015869141, 'actor_loss': 5.931163311004639, 'critic_loss': 52.11505889892578}\n", + "training epoch 18000 training loss {'value_loss': 4.535668849945068, 'actor_loss': 6.64661979675293, 'critic_loss': 39.68230056762695}\n", + "training epoch 18500 training loss {'value_loss': 5.359783172607422, 'actor_loss': 4.801802158355713, 'critic_loss': 48.881675720214844}\n", + "training epoch 19000 training loss {'value_loss': 4.827009201049805, 'actor_loss': 5.660262584686279, 'critic_loss': 38.47284698486328}\n", + "training epoch 19500 training loss {'value_loss': 5.641027450561523, 'actor_loss': 4.343751430511475, 'critic_loss': 49.17301559448242}\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Offline evaluation" + ], + "metadata": { + "id": "anb1UsbspTNE" + } + }, + { + "cell_type": "code", + "source": [ + "# Use the `offline_evaluation` utility function in Pearl to evaluate the trained agent by interacting with the environment.\n", + "\n", + "offline_evaluation_returns = offline_evaluation(\n", + " offline_agent=offline_agent,\n", + " env=env,\n", + " number_of_episodes=50,\n", + " seed=experiment_seed,\n", + ")\n", + "\n", + "# mean evaluation returns of the offline agent\n", + "avg_offline_agent_returns = torch.mean(torch.tensor(offline_evaluation_returns))\n", + "print(f\"average returns of the offline agent {avg_offline_agent_returns}\")" + ], + "metadata": { + "colab": { + 
"base_uri": "https://localhost:8080/" + }, + "id": "wAu_fL6EpXfL", + "outputId": "97718810-4cc1-4f52-af2e-45b6257aef2d" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "episode 49, return=244.23391946865013average returns of the offline agent 230.69460202311308\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Getting normalized scores (typical for benchmarking offline RL algorithms)\n" + ], + "metadata": { + "id": "gNddjWLf1EOk" + } + }, + { + "cell_type": "code", + "source": [ + "# Step 1: get episodic returns of a random agent using the 'returns_random_agent.pickle' file on github.\n", + "# You can also do this by initializing an offline agent randomly and interacting with the environment.\n", + "\n", + "\n", + "# URL of the file on GitHub\n", + "random_returns_url = \"https://github.com/facebookresearch/Pearl/raw/gh-pages/data/offline_rl_data/HalfCheetah/returns_random_agent.pickle\"\n", + "\n", + "# Download the file\n", + "response = requests.get(random_returns_url)\n", + "returns = response.content\n", + "\n", + "# Load the data from the file\n", + "with open('random_agent_returns.pkl', 'wb') as f:\n", + " f.write(returns)\n", + "with open('random_agent_returns.pkl', 'rb') as f:\n", + " random_agent_returns = pickle.load(f)\n", + "\n", + "\n", + "avg_return_random_agent = torch.mean(torch.tensor(random_agent_returns)) # mean returns of a random agent\n", + "print(f\"average returns of a random agent {avg_return_random_agent}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xvpt-b76zKI4", + "outputId": "069b3730-96d0-4660-dde7-34f17338c6f3" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "average returns of a random agent -426.930167355092\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Step 2: get training returns in the data (i.e. episodic returns of the data collection agent).\n", + "\n", + "# The `get_data_collection_agent_returns` function computes the episodic returns from the offline data of transition tuples\n", + "# Note 1: We implicitly assume that the offline data was collected in an episodic manner\n", + "# Note 2: The `data_path` points to the local file with offline data. 
Recall that we set data_path = cwd + '/' + filename, where filename = \"offline_raw_transitions_dict_small_2.pt\"\n", + "data_collection_agent_returns = get_data_collection_agent_returns(\n", + " data_path=data_path\n", + ")\n", + "\n", + "# average trajectory returns in the dataset\n", + "avg_return_data_collection_agent = torch.mean(\n", + " torch.tensor(data_collection_agent_returns)\n", + ")\n", + "print(\n", + " f\"average returns of the data collection agent {avg_return_data_collection_agent}\"\n", + ")\n", + "\n", + "\n", + "max_return_data_collection_agent = torch.max(torch.tensor(data_collection_agent_returns)) # maximum trajectory returns in the dataset\n", + "print(f\"maximum returns of the data collection agent {max_return_data_collection_agent}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NkgL8dVT5okl", + "outputId": "1b218ae1-aad5-403e-b666-5344d5b18fed" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "getting returns of the data collection agent agent\n", + "using offline training data in /content/offline_raw_transitions_dict_small_2.pt to stitch trajectories and compute returns\n", + "average returns of the data collection agent 490.7219223530043\n", + "maximum returns of the data collection agent 1878.2492669075727\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# The following is one way to define the normalized score, which is a proxy for\n", + "# how good the offline learning algorithm is as compared to the policy that was used to collect data.\n", + "\n", + "normalized_score = (avg_offline_agent_returns - avg_return_random_agent) / (\n", + " avg_return_data_collection_agent - avg_return_random_agent\n", + ")\n", + "\n", + "print(f\"normalized score {normalized_score}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "j3rVgiB76E8V", + "outputId": "3517c80e-64cd-4b5e-9144-d254db799402" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "normalized score 0.716638448006362\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Note 1: A normalized score close to 1 or greater than 1 indicates good performance. However, note that the offline dataset we use here is very small (100,000 transition tuples), so we do not expect to see a high normalized score.\n", + "\n", + "Note 2: We have used average episodic returns in the offline data as a proxy for the performance of data collection agent (which is used when computing the normalized score).\n", + "\n", + "- An ideal way to do this would be to run the data collection agent/policy, at the end of the data collection phase, for a few episodes and take the average episodic returns.\n", + "- This would approximate the 'best' policy used to collect data. The real test for offline learning algorithms is to be able to beat this 'best' policy." + ], + "metadata": { + "id": "h5jzoRxiRGe6" + } + } + ] +} \ No newline at end of file
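For concreteness, plugging the numbers printed above into this definition of the normalized score gives

$$\text{normalized score} = \frac{230.69 - (-426.93)}{490.72 - (-426.93)} = \frac{657.62}{917.65} \approx 0.717,$$

matching the value printed by the cell above. A score of 0 means the offline agent only matches a random policy, a score of 1 means it matches the average episodic return of the data collection agent, and beating the data collection agent corresponds to scores above 1.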
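The `expectile` and `temperature_advantage_weighted_regression` arguments passed to `ImplicitQLearning` above correspond, going by their names, to the expectile parameter and the advantage-weighted-regression temperature of IQL (Kostrikov et al., 2021), and they shape the `value_loss`, `critic_loss`, and `actor_loss` reported during training. For readers new to IQL, the snippet below is a minimal, self-contained PyTorch sketch of those three losses: expectile regression for the value network, a one-step TD loss for the critic, and advantage-weighted regression for the actor. It illustrates the underlying math only; it is not Pearl's `ImplicitQLearning` implementation, and the function and tensor names are placeholders.

```python
import torch

def iql_losses(
    q,            # Q(s, a) from the critic being trained
    q_target,     # Q(s, a) from the target critic
    v,            # V(s) from the value network
    v_next,       # V(s') from the value network
    reward,       # r
    done,         # 1.0 if s' is terminal, else 0.0
    log_prob,     # log pi(a | s) of the dataset action under the actor
    expectile=0.75,
    temperature=3.0,
    discount=0.99,  # placeholder discount factor for this sketch
):
    # 1) Value loss: expectile regression of V(s) toward Q_target(s, a).
    #    With expectile > 0.5, underestimating Q is penalized more than
    #    overestimating it, so V tracks an upper expectile of Q over the
    #    actions present in the dataset.
    diff = q_target.detach() - v
    weight = torch.abs(expectile - (diff < 0).float())
    value_loss = (weight * diff.pow(2)).mean()

    # 2) Critic loss: one-step TD error that bootstraps from V(s') rather than
    #    max_a' Q(s', a'), avoiding queries on out-of-distribution actions.
    td_target = reward + discount * (1.0 - done) * v_next.detach()
    critic_loss = (q - td_target).pow(2).mean()

    # 3) Actor loss: advantage-weighted regression. The temperature controls
    #    how strongly high-advantage dataset actions are up-weighted; the
    #    exponential weight is clipped for numerical stability.
    advantage = (q_target - v).detach()
    awr_weight = torch.exp(temperature * advantage).clamp(max=100.0)
    actor_loss = -(awr_weight * log_prob).mean()

    return {"value_loss": value_loss, "critic_loss": critic_loss, "actor_loss": actor_loss}


# Smoke test on random tensors, just to show the function runs.
batch = 256
rand = lambda: torch.randn(batch)
print(iql_losses(rand(), rand(), rand(), rand(), rand(), torch.zeros(batch), rand()))
```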
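If you would like to peek at the raw transitions file before handing it to `get_offline_data_in_buffer`, a generic inspection such as the one below can be run after the download cell. It makes no assumptions about the exact field names Pearl stored in the file; it only reports whatever top-level structure `torch.load` returns.

```python
import torch

# Load the downloaded file and report its top-level structure. The exact
# schema is whatever Pearl's data-collection utilities saved, so we only
# print what we find rather than assuming particular keys.
# (On newer PyTorch versions you may need torch.load(..., weights_only=False).)
raw = torch.load("offline_raw_transitions_dict_small_2.pt")
print("top-level type:", type(raw).__name__)

if isinstance(raw, dict):
    for key, value in raw.items():
        print(f"  {key}: {type(value).__name__}")
elif isinstance(raw, (list, tuple)) and len(raw) > 0:
    print("number of entries:", len(raw))
    first = raw[0]
    print("first entry type:", type(first).__name__)
    if isinstance(first, dict):
        print("first entry keys:", sorted(first.keys()))
```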