diff --git a/.circleci/docs/setup_env.sh b/.circleci/docs/setup_env.sh new file mode 100755 index 00000000000..496e57b29bd --- /dev/null +++ b/.circleci/docs/setup_env.sh @@ -0,0 +1,66 @@ +#apt-get update -y +#apt-get install software-properties-common -y +#add-apt-repository ppa:git-core/candidate -y +#apt-get update -y +#apt-get upgrade -y +#apt-get -y install libglfw3 libglew2.0 gcc curl g++ unzip \ +# wget sudo git cmake libz-dev \ +# zlib1g-dev python3.8 python3-pip ninja + +#yum install -y mesa-libGL freeglut egl-utils glew glfw +#yum install -y glew glfw +apt-get update && apt-get install -y git wget gcc g++ + +root_dir="$(pwd)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +os=Linux + +# 1. Install conda at ./conda +printf "* Installing conda\n" +wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" +bash ./miniconda.sh -b -f -p "${conda_dir}" + +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +printf "* Creating a test environment\n" +conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" + +printf "* Activating\n" +conda activate "${env_dir}" + +conda install -c conda-forge zlib -y + +pip3 install --upgrade pip --quiet --root-user-action=ignore + +printf "python version\n" +python --version + +pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu118 --quiet --root-user-action=ignore +#pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu --quiet --root-user-action=ignore + +printf "Installing tensordict\n" +pip3 install git+https://github.com/pytorch-labs/tensordict.git --quiet --root-user-action=ignore + +printf "Installing torchrl\n" +pip3 install -e . --quiet --root-user-action=ignore + +printf "Installing requirements\n" +pip3 install -r docs/requirements.txt --quiet --root-user-action=ignore +printf "Installed all dependencies\n" + +printf "smoke test\n" +PYOPENGL_PLATFORM=egl MUJOCO_GL=egl python3 -c """from torchrl.envs.libs.dm_control import DMControlEnv +print(DMControlEnv('cheetah', 'run').reset()) +""" + +printf "building docs...\n" +cd ./docs +#timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build SPHINXOPTS=-v ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi +PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build +cd .. +printf "done!\n" + +git clone --branch gh-pages https://github.com/pytorch-labs/tensordict.git docs/_local_build/tensordict +rm -rf docs/_local_build/tensordict/.git diff --git a/.circleci/unittest/linux_libs/scripts_habitat/install.sh b/.circleci/unittest/linux_libs/scripts_habitat/install.sh index 8fb340c567c..437900b3323 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/install.sh +++ b/.circleci/unittest/linux_libs/scripts_habitat/install.sh @@ -5,6 +5,7 @@ unset PYTORCH_VERSION # so no need to set PYTORCH_VERSION. # In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev set -e diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 1d595983439..636475823ca 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,91 +10,53 @@ on: workflow_dispatch: jobs: build_docs_job: + strategy: + matrix: + python_version: ["3.8"] # "3.8", "3.9", "3.10", "3.11" + cuda_arch_version: ["11.8"] # "11.6", "11.7" + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.g5.4xlarge.nvidia.gpu + repository: pytorch/rl + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda_arch_version }} + docker-image: nvidia/cudagl:11.4.0-base + timeout: 45 + script: | + # Set env vars from matrix + export PYTHON_VERSION=${{ matrix.python_version }} + # Commenting these out for now because the GPU test are not working inside docker + export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} + export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" + # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines + #export CU_VERSION="cpu" + + echo "PYTHON_VERSION: $PYTHON_VERSION" + echo "CU_VERSION: $CU_VERSION" + + cd /work + + ## setup_env.sh + ./.circleci/docs/setup_env.sh + + deploy: + needs: build_docs_job runs-on: ubuntu-20.04 strategy: matrix: include: - - os: linux.4xlarge.nvidia.gpu + - os: linux.12xlarge python-version: 3.8 defaults: run: shell: bash -l {0} container: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu18.04 steps: - - name: Install deps - run: | - apt-get update -y - apt-get install software-properties-common -y - add-apt-repository ppa:git-core/candidate -y - apt-get update -y - apt-get upgrade -y - apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev - - name: Install rsync 📚 - run: | - apt-get update && apt-get install -y rsync - - name: Check ldd --version - run: ldd --version - name: Checkout uses: actions/checkout@v3 - # Update references - - name: Update pip - run: | - apt-get install python3.8 python3-pip -y - pip3 install --upgrade pip - - name: Setup conda - run: | - rm -rf $HOME/miniconda - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh - bash ~/miniconda.sh -b -p $HOME/miniconda - - name: setup Path - run: | - echo "$HOME/miniconda/bin" >> $GITHUB_PATH - echo "CONDA=$HOME/miniconda" >> $GITHUB_PATH - - name: create and activate conda env - run: | - $HOME/miniconda/bin/conda create --name build_binary python=${{ matrix.python-version }} - $HOME/miniconda/bin/conda info - $HOME/miniconda/bin/activate build_binary - - name: check python version - run: | - python --version - - name: Check git version - run: git version - - name: setup Path - run: | - echo /usr/local/bin >> $GITHUB_PATH - - name: Install PyTorch - shell: bash - run: | - python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - - name: Install tensordict - run: | - python3 -m pip install git+https://github.com/pytorch-labs/tensordict.git - - name: Install TorchRL - run: | - python -m pip install -e . - - name: Test torchrl installation - shell: bash - run: | - mkdir _tmp - cd _tmp - python -c "import torchrl;from torchrl.envs.libs.dm_control import DMControlEnv" - cd .. 
- - name: Build the docset - id: build_doc - run: | - python -m pip install -r docs/requirements.txt - cd ./docs - timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi - cd .. - - name: Pull TensorDict docs - run: | - git clone --branch gh-pages https://github.com/pytorch-labs/tensordict.git docs/_local_build/tensordict - rm -rf docs/_local_build/tensordict/.git - - name: Get output time - run: echo "The time was ${{ steps.build.outputs.time }}" - name: Deploy - if: ${{ github.ref == 'refs/heads/main' }} + if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} uses: JamesIves/github-pages-deploy-action@releases/v4 with: token: ${{ secrets.GITHUB_TOKEN }} diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf1020..a5ca087a556 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +SPHINXOPTS = -v SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build diff --git a/docs/requirements.txt b/docs/requirements.txt index 2ce3140692d..a90869b836d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -12,16 +12,14 @@ sphinxcontrib-htmlhelp myst-parser docutils -functorch -gym[classic_control] torchvision dm_control atari-py ale-py -gym[accept-rom-license] +gym[classic_control,accept-rom-license] pygame tqdm ipython -imageio -imageio[ffmpeg] -imageio[pyav] +imageio[ffmpeg,pyav] +memory_profiler +pyrender diff --git a/docs/source/conf.py b/docs/source/conf.py index fdb1fc10c75..e278c653619 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -88,6 +88,9 @@ "filename_pattern": "reference/generated/tutorials/", # files to parse "notebook_images": "reference/generated/tutorials/media/", # images to parse "download_all_examples": True, + "abort_on_example_error": False, + "only_warn_on_example_error": True, + "show_memory": True, } napoleon_use_ivar = True diff --git a/docs/source/index.rst b/docs/source/index.rst index e56e32edfc9..79fa0cddb15 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,18 +34,19 @@ Basics ------ .. toctree:: - :maxdepth: 2 + :maxdepth: 1 - tutorials/torchrl_demo tutorials/coding_ppo + tutorials/pendulum tutorials/tensordict_tutorial tutorials/tensordict_module + tutorials/torchrl_demo Intermediate ------------ .. toctree:: - :maxdepth: 2 + :maxdepth: 1 tutorials/torch_envs tutorials/pretrained_models @@ -54,7 +55,7 @@ Advanced -------- .. toctree:: - :maxdepth: 2 + :maxdepth: 1 tutorials/multi_task tutorials/coding_ddpg diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index e0a1962bc19..34d6271811f 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -14,6 +14,40 @@ The :obj:`trainer.train()` method can be sketched as follows: .. code-block:: :caption: Trainer loops + >>> for batch in collector: + ... batch = self._process_batch_hook(batch) # "batch_process" + ... self._pre_steps_log_hook(batch) # "pre_steps_log" + ... self._pre_optim_hook() # "pre_optim_steps" + ... for j in range(self.optim_steps_per_batch): + ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" + ... losses = self.loss_module(sub_batch) + ... self._post_loss_hook(sub_batch) # "post_loss" + ... self.optimizer.step() + ... self.optimizer.zero_grad() + ... self._post_optim_hook() # "post_optim" + ... 
self._post_optim_log(sub_batch) # "post_optim_log" + ... self._post_steps_hook() # "post_steps" + ... self._post_steps_log_hook(batch) # "post_steps_log" + + There are 10 hooks that can be used in a trainer loop: + + >>> for batch in collector: + ... batch = self._process_batch_hook(batch) # "batch_process" + ... self._pre_steps_log_hook(batch) # "pre_steps_log" + ... self._pre_optim_hook() # "pre_optim_steps" + ... for j in range(self.optim_steps_per_batch): + ... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch" + ... losses = self.loss_module(sub_batch) + ... self._post_loss_hook(sub_batch) # "post_loss" + ... self.optimizer.step() + ... self.optimizer.zero_grad() + ... self._post_optim_hook() # "post_optim" + ... self._post_optim_log(sub_batch) # "post_optim_log" + ... self._post_steps_hook() # "post_steps" + ... self._post_steps_log_hook(batch) # "post_steps_log" + + There are 10 hooks that can be used in a trainer loop: + >>> for batch in collector: ... batch = self._process_batch_hook(batch) # "batch_process" ... self._pre_steps_log_hook(batch) # "pre_steps_log" diff --git a/knowledge_base/MUJOCO_INSTALLATION.md b/knowledge_base/MUJOCO_INSTALLATION.md index 65814cdf69a..dd080e2774a 100644 --- a/knowledge_base/MUJOCO_INSTALLATION.md +++ b/knowledge_base/MUJOCO_INSTALLATION.md @@ -129,6 +129,7 @@ issues when running `import mujoco_py` and some troubleshooting for each of them #include ^~~~~~~~~~~ ``` + _Solution_: make sure glew is installed (see above: `conda install -c conda-forge glew` or the `apt-get` version of it). 2. ``` @@ -136,14 +137,15 @@ issues when running `import mujoco_py` and some troubleshooting for each of them #include ^~~~~~~~~ ``` + _Solution_: This should disappear once `mesalib` is installed: `conda install -y -c conda-forge mesalib` -3. +4. ``` FileNotFoundError: [Errno 2] No such file or directory: 'patchelf' ``` - _Solution_: `pip install patchelf` -4. + _Solution_: `pip install patchelf` +5. ``` ImportError: /usr/lib/x86_64-linux-gnu/libOpenGL.so.0: undefined symbol: _glapi_tls_Current ``` @@ -155,7 +157,7 @@ issues when running `import mujoco_py` and some troubleshooting for each of them conda env config vars set LD_PRELOAD=/path/to/conda/envs/mujoco_env/x86_64-conda-linux-gnu/sysroot/usr/lib64/libGLdispatch.so.0 ``` -5. +6. ``` mujoco.FatalError: gladLoadGL error @@ -166,6 +168,7 @@ issues when running `import mujoco_py` and some troubleshooting for each of them **Sanity check** + To check that your mujoco-py has been built against the GPU, run ```python >>> import mujoco_py @@ -191,17 +194,18 @@ RuntimeError: Failed to initialize OpenGL 2. Rendered images are completely black. -> Make sure to call `env.render()` before reading the pixels. + _Solution_: Make sure to call `env.render()` before reading the pixels. 3. `patchelf` dependency is missing. -> Install using `conda install patchelf` or `pip install patchelf` + _Solution_: Install using `conda install patchelf` or `pip install patchelf` 4. Errors like "Onscreen rendering needs 101 device" -> Make sure to set `DISPLAY` environment variable correctly. + _Solution_: Make sure to set `DISPLAY` environment variable correctly. 5. `ImportError: Cannot initialize a headless EGL display.` - Make sure you have installed mujoco and all its dependencies (see instructions above). - Make sure you have set the `MUJOCO_GL=egl`. - Make sure you have a GPU accessible on your machine. + + _Solution_: Make sure you have installed mujoco and all its dependencies (see instructions above). 
+ Make sure you have set the `MUJOCO_GL=egl`. + Make sure you have a GPU accessible on your machine. diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 8facef31323..6286f3418ed 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -691,6 +691,7 @@ def shutdown(self) -> None: if not self.env.is_closed: self.env.close() del self.env + return def __del__(self): try: @@ -1052,7 +1053,6 @@ def _shutdown_main(self) -> None: for idx in range(self.num_workers): self.pipes[idx].send((None, "close")) - for idx in range(self.num_workers): msg = self.pipes[idx].recv() if msg != "closed": raise RuntimeError(f"got {msg} but expected 'close'") @@ -1621,7 +1621,7 @@ def _main_async_collector( print(f"worker {idx} received {msg}") else: if verbose: - print(f"poll failed, j={j}") + print(f"poll failed, j={j}, worker={idx}") # default is "continue" (after first iteration) # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe # in that case, the main process probably expects the worker to continue collect data @@ -1638,7 +1638,10 @@ def _main_async_collector( # this means that our process has been waiting for a command from main in vain, while main was not # receiving data. # This will occur if main is busy doing something else (e.g. computing loss etc). + counter += _timeout + if verbose: + print(f"worker {idx} has counter {counter}") if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): raise RuntimeError( f"This process waited for {counter} seconds " diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 68d70145974..9d73ea9fdc7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -595,7 +595,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten f"{self.__class__.__name__}.index(...)" ) index = index.nonzero().squeeze() - index = index.expand(*tensor_to_index.shape[:-1], index.shape[-1]) + index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1])) return tensor_to_index.gather(-1, index) def _project(self, val: torch.Tensor) -> torch.Tensor: @@ -679,17 +679,17 @@ def __init__( if shape is not None and shape != maximum.shape: raise RuntimeError(err_msg) shape = maximum.shape - minimum = minimum.expand(*shape).clone() + minimum = minimum.expand(shape).clone() elif minimum.ndimension(): if shape is not None and shape != minimum.shape: raise RuntimeError(err_msg) shape = minimum.shape - maximum = maximum.expand(*shape).clone() + maximum = maximum.expand(shape).clone() elif shape is None: raise RuntimeError(err_msg) else: - minimum = minimum.expand(*shape).clone() - maximum = maximum.expand(*shape).clone() + minimum = minimum.expand(shape).clone() + maximum = maximum.expand(shape).clone() if minimum.numel() > maximum.numel(): maximum = maximum.expand_as(minimum).clone() @@ -1028,7 +1028,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten f" {self.__class__.__name__}.index(...)" ) index = index.nonzero().squeeze() - index = index.expand(*tensor_to_index.shape[:-1], index.shape[-1]) + index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1])) return tensor_to_index.gather(-1, index) def is_in(self, val: torch.Tensor) -> bool: @@ -1203,7 +1203,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten out = [] for _index, _tensor_to_index in zip(indices, tensor_to_index): _index = _index.nonzero().squeeze() - _index = 
_index.expand(*_tensor_to_index.shape[:-1], _index.shape[-1]) + _index = _index.expand((*_tensor_to_index.shape[:-1], _index.shape[-1])) out.append(_tensor_to_index.gather(-1, _index)) return torch.cat(out, -1) @@ -1941,7 +1941,7 @@ def expand(self, *shape): ) out = CompositeSpec( { - key: value.expand(*shape, *value.shape[self.ndim :]) + key: value.expand((*shape, *value.shape[self.ndim :])) for key, value in tuple(self.items()) }, shape=shape, diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5b54571d7a4..7dc749289a9 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -595,11 +595,21 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa be stored with the "action" key. """ + shape = torch.Size([]) if tensordict is None: tensordict = TensorDict( {}, device=self.device, batch_size=self.batch_size, _run_checks=False ) - action = self.action_spec.rand() + elif not self.batch_locked and not self.batch_size: + shape = tensordict.shape + elif not self.batch_locked and tensordict.shape != self.batch_size: + raise RuntimeError( + "The input tensordict and the env have a different batch size: " + f"env.batch_size={self.batch_size} and tensordict.batch_size={tensordict.shape}. " + f"Non batch-locked environment require the env batch-size to be either empty or to" + f" match the tensordict one." + ) + action = self.action_spec.rand(shape) tensordict.set("action", action) return self.step(tensordict) @@ -670,7 +680,8 @@ def rollout( if policy is None: def policy(td): - return td.set("action", self.action_spec.rand()) + self.rand_step(td) + return td tensordicts = [] for i in range(max_steps): diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index aece9e33dda..fbcd6d150cd 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -38,7 +38,9 @@ except ImportError as err: _has_dmc = False - IMPORT_ERR = str(err) + IMPORT_ERR = err +else: + IMPORT_ERR = None __all__ = ["DMControlEnv", "DMControlWrapper"] @@ -299,10 +301,8 @@ class DMControlEnv(DMControlWrapper): def __init__(self, env_name, task_name, **kwargs): if not _has_dmc: raise ImportError( - f"""dm_control python package was not found. Please install this dependency. -(Got the error message: {IMPORT_ERR}). -""" - ) + "dm_control python package was not found. Please install this dependency." 
+ ) from IMPORT_ERR kwargs["env_name"] = env_name kwargs["task_name"] = task_name super().__init__(**kwargs) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e6aae646e98..4fa1faba4a6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1593,7 +1593,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec def __repr__(self) -> str: - if self.loc.numel() == 1 and self.scale.numel() == 1: + if self.initialized and (self.loc.numel() == 1 and self.scale.numel() == 1): return ( f"{self.__class__.__name__}(" f"loc={float(self.loc):4.4f}, scale" diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 7b34d927607..430f4632836 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -230,19 +230,11 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): # test dtypes real_tensordict = env.rollout(3) # keep empty structures, for example dict() - # real_tensordict = real_tensordict.exclude( - # *[ - # key - # for key in real_tensordict.keys(True) - # if (isinstance(key, str) and key.startswith("_")) - # or ( - # isinstance(key, tuple) and any(subkey.startswith("_") for subkey in key) - # ) - # ] - # ) for key, value in real_tensordict[..., -1].items(): _check_isin(key, value, env.observation_spec, env.input_spec) + print("check_env_specs succeeded!") + def _check_isin(key, value, obs_spec, input_spec): if key in {"reward", "done"}: diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 2c799e17fb5..364c0dec725 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -15,5 +15,6 @@ RewardNormalizer, SelectKeys, Trainer, + TrainerHookBase, UpdateWeights, ) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index fa92fa6840c..a0f808295e6 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -1,17 +1,26 @@ # -*- coding: utf-8 -*- """ Coding DDPG using TorchRL -============================ +========================= +**Author**: `Vincent Moens `_ + """ ############################################################################## # This tutorial will guide you through the steps to code DDPG from scratch. -# DDPG (`Deep Deterministic Policy Gradient `_) -# is a simple continuous control algorithm. It essentially consists in -# learning a parametric value function for an action-observation pair, and +# +# DDPG (`Deep Deterministic Policy Gradient _`_) +# is a simple continuous control algorithm. It consists in learning a +# parametric value function for an action-observation pair, and # then learning a policy that outputs actions that maximise this value # function given a certain observation. # -# In this tutorial, you will learn: +# This tutorial is more than the PPO tutorial: it covers +# multiple topics that were left aside. We strongly advise the reader to go +# through the PPO tutorial first before trying out this one. The goal is to +# show how flexible torchrl is when it comes to writing scripts that can cover +# multiple use cases. +# +# Key learnings: # # - how to build an environment in TorchRL, including transforms # (e.g. data normalization) and parallel execution; @@ -22,15 +31,18 @@ # - and finally how to evaluate your model. 
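#
# As a quick refresher on the primitives used throughout, the following is a
# minimal sketch (not part of the original tutorial; the key names and shapes
# are arbitrary) of how a :class:`tensordict.nn.TensorDictModule` reads its
# inputs from and writes its outputs to a :class:`tensordict.TensorDict`:
#
# .. code-block:: python
#
#    import torch
#    from torch import nn
#    from tensordict import TensorDict
#    from tensordict.nn import TensorDictModule
#
#    net = nn.Linear(3, 2)
#    module = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])
#    data = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])
#    module(data)                 # writes an "action" entry of shape [4, 2] in place
#    print(data["action"].shape)  # torch.Size([4, 2])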
# # This tutorial assumes the reader is familiar with some of TorchRL primitives, -# such as ``TensorDict`` and ``TensorDictModules``, although it should be +# such as :class:`tensordict.TensorDict` and +# :class:`tensordict.nn.TensorDictModules`, although it should be # sufficiently transparent to be understood without a deep understanding of # these classes. # # We do not aim at giving a SOTA implementation of the algorithm, but rather # to provide a high-level illustration of TorchRL features in the context of # this algorithm. - -# Make all the necessary imports for training +# +# Imports +# ------- +# # sphinx_gallery_start_ignore import warnings @@ -39,7 +51,6 @@ # sphinx_gallery_end_ignore from copy import deepcopy -from typing import Optional import numpy as np import torch @@ -76,30 +87,47 @@ ############################################################################### # Environment -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Let us start by building the environment. +# ----------- +# +# In most algorithms, the first thing that needs to be taken care of is the +# construction of the environmet as it conditions the remainder of the +# training script. # -# For this example, we will be using the cheetah task. The goal is to make +# For this example, we will be using the ``"cheetah"`` task. The goal is to make # a half-cheetah run as fast as possible. # # In TorchRL, one can create such a task by relying on dm_control or gym: # -# env = GymEnv("HalfCheetah-v4") +# .. code-block:: python +# +# env = GymEnv("HalfCheetah-v4") # # or # -# env = DMControlEnv("cheetah", "run") +# .. code-block:: python +# +# env = DMControlEnv("cheetah", "run") +# +# By default, these environment disable rendering. Training from states is +# usually easier than training from images. To keep things simple, we focus +# on learning from states only. To pass the pixels to the tensordicts that +# are collected by :func:`env.step()`, simply pass the ``from_pixels=True`` +# argument to the constructor: +# +# .. code-block:: python +# +# env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True) +# +# We write a :func:`make_env` helper funciton that will create an environment +# with either one of the two backends considered above (dm-control or gym). # -# We only consider the state-based environment, but if one wishes to use a -# pixel-based environment, this can be done via the keyword argument -# ``from_pixels=True`` which is passed when calling ``GymEnv`` or -# ``DMControlEnv``. + +env_library = None +env_name = None def make_env(): - """ - Create a base env - """ + """Create a base env.""" global env_library global env_name @@ -127,31 +155,44 @@ def make_env(): ############################################################################### # Transforms -# ------------------------------ +# ^^^^^^^^^^ +# # Now that we have a base environment, we may want to modify its representation -# to make it more policy-friendly. +# to make it more policy-friendly. In TorchRL, transforms are appended to the +# base environment in a specialized :class:`torchr.envs.TransformedEnv` class. # -# It is common in DDPG to rescale the reward using some heuristic value. We -# will multiply the reward by 5 in this example. +# - It is common in DDPG to rescale the reward using some heuristic value. We +# will multiply the reward by 5 in this example. # -# If we are using dm_control, it is important also to transform the actions -# to double precision numbers as this is the dtype expected by the library. 
+# - If we are using :mod:`dm_control`, it is also important to build an interface +# between the simulator which works with double precision numbers, and our +# script which presumably uses single precision ones. This transformation goes +# both ways: when calling :func:`env.step`, our actions will need to be +# represented in double precision, and the output will need to be transformed +# to single precision. +# The :class:`torchrl.envs.DoubleToFloat` transform does exactly this: the +# ``in_keys`` list refers to the keys that will need to be transformed from +# double to float, while the ``in_keys_inv`` refers to those that need to +# be transformed to double before being passed to the environment. +# +# - We concatenate the state keys together using the :class:`torchrl.envs.CatTensors` +# transform. +# +# - Finally, we also leave the possibility of normalizing the states: we will +# take care of computing the normalizing constants later on. # -# We also leave the possibility to normalize the states: we will take care of -# computing the normalizing constants later on. def make_transformed_env( env, - stats=None, ): - """ - Apply transforms to the env (such as reward scaling and state normalization) - """ + """Apply transforms to the env (such as reward scaling and state normalization).""" env = TransformedEnv(env) - # we append transforms one by one, although we might as well create the transformed environment using the `env = TransformedEnv(base_env, transforms)` syntax. + # we append transforms one by one, although we might as well create the + # transformed environment using the `env = TransformedEnv(base_env, transforms)` + # syntax. env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) double_to_float_list = [] @@ -166,20 +207,18 @@ def make_transformed_env( # We concatenate all states into a single "observation_vector" # even if there is a single tensor, it'll be renamed in "observation_vector". - # This facilitates the downstream operations as we know the name of the output tensor. - # In some environments (not half-cheetah), there may be more than one observation vector: in this case this code snippet will concatenate them all. + # This facilitates the downstream operations as we know the name of the + # output tensor. + # In some environments (not half-cheetah), there may be more than one + # observation vector: in this case this code snippet will concatenate them + # all. selected_keys = list(env.observation_spec.keys()) out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states - if stats is None: - _stats = {"loc": 0.0, "scale": 1.0} - else: - _stats = stats - env.append_transform( - ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) - ) + # we normalize the states, but for now let's just instantiate a stateless + # version of the transform + env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True)) double_to_float_list.append(out_key) env.append_transform( @@ -191,24 +230,70 @@ def make_transformed_env( return env +############################################################################### +# Normalization of the observations +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# To compute the normalizing statistics, we run an arbitrary number of random +# steps in the environment and compute the mean and standard deviation of the +# collected observations. The :func:`ObservationNorm.init_stats()` method can +# be used for this purpose. 
To get the summary statistics, we create a dummy +# environment and run it for a given number of steps, collect data over a given +# number of steps and compute its summary statistics. +# + + +def get_env_stats(): + """Gets the stats of an environment.""" + proof_env = make_transformed_env(make_env()) + proof_env.set_seed(seed) + t = proof_env.transform[2] + t.init_stats(init_env_steps) + transform_state_dict = t.state_dict() + proof_env.close() + return transform_state_dict + + ############################################################################### # Parallel execution -# ------------------------------ +# ^^^^^^^^^^^^^^^^^^ +# # The following helper function allows us to run environments in parallel. -# One can choose between running each base env in a separate process and -# execute the transform in the main process, or execute the transforms in -# parallel. To leverage the vectorization capabilities of PyTorch, we adopt +# Running environments in parallel can significantly speed up the collection +# throughput. When using transformed environment, we need to choose whether we +# want to execute the transform individually for each environment, or +# centralize the data and transform it in batch. Both approaches are easy to +# code: +# +# .. code-block:: python +# +# env = ParallelEnv( +# lambda: TransformedEnv(GymEnv("HalfCheetah-v4"), transforms), +# num_workers=4 +# ) +# env = TransformedEnv( +# ParallelEnv(lambda: GymEnv("HalfCheetah-v4"), num_workers=4), +# transforms +# ) +# +# To leverage the vectorization capabilities of PyTorch, we adopt # the first method: +# def parallel_env_constructor( - stats, - **env_kwargs, + transform_state_dict, ): if env_per_collector == 1: - env_creator = EnvCreator( - lambda: make_transformed_env(make_env(), stats, **env_kwargs) - ) + + def make_t_env(): + env = make_transformed_env(make_env()) + env.transform[2].init_stats(3) + env.transform[2].loc.copy_(transform_state_dict["loc"]) + env.transform[2].scale.copy_(transform_state_dict["scale"]) + return env + + env_creator = EnvCreator(make_t_env) return env_creator parallel_env = ParallelEnv( @@ -217,78 +302,59 @@ def parallel_env_constructor( create_env_kwargs=None, pin_memory=False, ) - env = make_transformed_env(parallel_env, stats, **env_kwargs) + env = make_transformed_env(parallel_env) + # we call `init_stats` for a limited number of steps, just to instantiate + # the lazy buffers. 
+ env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1]) + env.transform[2].load_state_dict(transform_state_dict) return env -############################################################################### -# Normalization of the observations -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# To compute the normalizing statistics, we run an arbitrary number of random -# steps in the environment and compute the mean and standard deviation of the -# collected observations: - - -def get_stats_random_rollout(proof_environment, key: Optional[str] = None): - print("computing state stats") - n = 0 - td_stats = [] - while n < init_env_steps: - _td_stats = proof_environment.rollout(max_steps=init_env_steps) - n += _td_stats.numel() - _td_stats_select = _td_stats.to_tensordict().select(key).cpu() - if not len(list(_td_stats_select.keys())): - raise RuntimeError( - f"key {key} not found in tensordict with keys {list(_td_stats.keys())}" - ) - td_stats.append(_td_stats_select) - del _td_stats, _td_stats_select - td_stats = torch.cat(td_stats, 0) - - m = td_stats.get(key).mean(dim=0) - s = td_stats.get(key).std(dim=0) - m[s == 0] = 0.0 - s[s == 0] = 1.0 - - print( - f"stats computed for {td_stats.numel()} steps. Got: \n" - f"loc = {m}, \n" - f"scale: {s}" - ) - if not torch.isfinite(m).all(): - raise RuntimeError("non-finite values found in mean") - if not torch.isfinite(s).all(): - raise RuntimeError("non-finite values found in sd") - stats = {"loc": m, "scale": s} - return stats - - -def get_env_stats(): - """ - Gets the stats of an environment - """ - proof_env = make_transformed_env(make_env(), None) - proof_env.set_seed(seed) - stats = get_stats_random_rollout( - proof_env, - key="observation_vector", - ) - # make sure proof_env is closed - proof_env.close() - return stats - - ############################################################################### # Building the model -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Let us now build the DDPG actor and QValue network. +# ------------------ +# +# We now turn to the setup of the model and loss function. DDPG requires a +# value network, trained to estimate the value of a state-action pair, and a +# parametric actor that learns how to select actions that maximize this value. +# In this tutorial, we will be using two independent networks for these +# components. +# +# Recall that building a torchrl module requires two steps: +# +# - writing the :class:`torch.nn.Module` that will be used as network +# - wrapping the network in a :class:`tensordict.nn.TensorDictModule` where the +# data flow is handled by specifying the input and output keys. +# +# In more complex scenarios, :class:`tensordict.nn.TensorDictSequential` can +# also be used. +# +# In :func:`make_ddpg_actor`, we use a :class:`torchrl.modules.ProbabilisticActor` +# object to wrap our policy network. Since DDPG is a deterministic algorithm, +# this is not strictly necessary. We rely on this class to map the output +# action to the appropriate domain. Alternatively, one could perfectly use a +# non-linearity such as :class:`torch.tanh` to map the output to the right +# domain. +# +# The Q-Value network is wrapped in a :class:`torchrl.modules.ValueOperator` +# that automatically sets the ``out_keys`` to ``"state_action_value`` for q-value +# networks and ``state_value`` for other value networks. +# +# Since we use lazy modules, it is necessary to materialize the lazy modules +# before being able to move the policy from device to device and achieve other +# operations. 
Hence, it is good practice to run the modules with a small +# sample of data. For this purpose, we generate fake data from the +# environment specs. +# def make_ddpg_actor( - stats, + transform_state_dict, device="cpu", ): - proof_environment = make_transformed_env(make_env(), stats) + proof_environment = make_transformed_env(make_env()) + proof_environment.transform[2].init_stats(3) + proof_environment.transform[2].load_state_dict(transform_state_dict) env_specs = proof_environment.specs out_features = env_specs["action_spec"].shape[0] @@ -332,7 +398,8 @@ def make_ddpg_actor( # init: since we have lazy layers, we should run the network # once to initialize them with torch.no_grad(), set_exploration_mode("random"): - td = proof_environment.rollout(max_steps=4) + td = proof_environment.fake_tensordict() + td = td.expand((*td.shape, 2)) td = td.to(device) actor(td) qnet(td) @@ -342,17 +409,22 @@ def make_ddpg_actor( ############################################################################### # Evaluator: building your recorder object -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ---------------------------------------- +# # As the training data is obtained using some exploration strategy, the true # performance of our algorithm needs to be assessed in deterministic mode. We # do this using a dedicated class, ``Recorder``, which executes the policy in # the environment at a given frequency and returns some statistics obtained -# from these simulations. The following helper function builds this object: +# from these simulations. +# +# The following helper function builds this object: -def make_recorder(actor_model_explore, stats): +def make_recorder(actor_model_explore, transform_state_dict): base_env = make_env() - recorder = make_transformed_env(base_env, stats) + recorder = make_transformed_env(base_env) + recorder.transform[2].init_stats(3) + recorder.transform[2].load_state_dict(transform_state_dict) recorder_obj = Recorder( record_frames=1000, @@ -367,18 +439,21 @@ def make_recorder(actor_model_explore, stats): ############################################################################### # Replay buffer -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Replay buffers come in two flavours: prioritized (where some error signal +# ------------- +# +# Replay buffers come in two flavors: prioritized (where some error signal # is used to give a higher likelihood of sampling to some items than others) # and regular, circular experience replay. # -# We also provide a special storage, names LazyMemmapStorage, that will +# TorchRL replay buffers are composable: one can pick up the storage, sampling +# and writing strategies. It is also possible to # store tensors on physical memory using a memory-mapped array. 
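#
# As an illustration (a minimal sketch, separate from the helper defined below;
# the ``"obs"`` key and the sizes are arbitrary), a composed buffer can be
# created, extended and sampled as follows:
#
# .. code-block:: python
#
#    import torch
#    from tensordict import TensorDict
#    from torchrl.data import TensorDictReplayBuffer
#    from torchrl.data.replay_buffers import LazyMemmapStorage, RandomSampler
#
#    rb = TensorDictReplayBuffer(
#        storage=LazyMemmapStorage(1000),  # tensors live in a memory-mapped file
#        sampler=RandomSampler(),
#    )
#    rb.extend(TensorDict({"obs": torch.randn(16, 3)}, batch_size=[16]))
#    sample = rb.sample(4)                 # a TensorDict with batch_size [4]
#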
The following # function takes care of creating the replay buffer with the desired # hyperparameters: +# -def make_replay_buffer(make_replay_buffer=3): +def make_replay_buffer(buffer_size, prefetch=3): if prb: sampler = PrioritizedSampler( max_capacity=buffer_size, @@ -395,93 +470,141 @@ def make_replay_buffer(make_replay_buffer=3): ), sampler=sampler, pin_memory=False, - prefetch=make_replay_buffer, + prefetch=prefetch, ) return replay_buffer ############################################################################### # Hyperparameters -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# After having written all our helper functions, it is now time to set the +# --------------- +# +# After having written our helper functions, it is time to set the # experiment hyperparameters: -backend = "gym" # or "dm_control" -frame_skip = 2 # if this value is changed, the number of frames collected etc. need to be adjusted +############################################################################### +# Environment +# ^^^^^^^^^^^ + +# The backend can be gym or dm_control +backend = "gym" + +exp_name = "cheetah" + +# frame_skip batches multiple step together with a single action +# If > 1, the other frame counts (e.g. frames_per_batch, total_frames) need to +# be adjusted to have a consistent total number of frames collected across +# experiments. +frame_skip = 2 from_pixels = False +# Scaling the reward helps us control the signal magnitude for a more +# efficient learning. reward_scaling = 5.0 -# execute on cuda if available +# Number of random steps used as for stats computation using ObservationNorm +init_env_steps = 1000 + +# Exploration: Number of frames before OU noise becomes null +annealing_frames = 1000000 // frame_skip + +############################################################################### +# Collection +# ^^^^^^^^^^ + +# We will execute the policy on cuda if available device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") ) -init_env_steps = 1000 # number of random steps used as for stats computation -env_per_collector = 2 # number of environments in each data collector +# Number of environments in each data collector +env_per_collector = 2 + +# Total frames we will use during training. 
Scale up to 500K - 1M for a more +# meaningful training +total_frames = 5000 // frame_skip +# Number of frames returned by the collector at each iteration of the outer loop +frames_per_batch = 1000 // frame_skip +init_random_frames = 0 +# We'll be using the MultiStep class to have a less myopic representation of +# upcoming states +n_steps_forward = 3 -env_library = None # overwritten because global in env maker -env_name = None # overwritten because global in env maker +# record every 10 batch collected +record_interval = 10 + +############################################################################### +# Optimizer and optimization +# ^^^^^^^^^^^^^^^^^^^^^^^^^^ -exp_name = "cheetah" -annealing_frames = ( - 1000000 // frame_skip -) # Number of frames before OU noise becomes null lr = 5e-4 weight_decay = 0.0 -total_frames = 5000 // frame_skip -init_random_frames = 0 -# init_random_frames = 5000 // frame_skip # Number of random frames used as warm-up -optim_steps_per_batch = 32 # Number of iterations of the inner loop +# UTD: Number of iterations of the inner loop +update_to_data = 32 batch_size = 128 -frames_per_batch = ( - 1000 // frame_skip -) # Number of frames returned by the collector at each iteration of the outer loop + +############################################################################### +# Model +# ^^^^^ + gamma = 0.99 tau = 0.005 # Decay factor for the target network -prb = True # If True, a Prioritized replay buffer will be used -buffer_size = min( - total_frames, 1000000 // frame_skip -) # Number of frames stored in the buffer -buffer_scratch_dir = "/tmp/" -n_steps_forward = 3 - -record_interval = 10 # record every 10 batch collected # Network specs num_cells = 64 num_layers = 2 +############################################################################### +# Replay buffer +# ^^^^^^^^^^^^^ + +# If True, a Prioritized replay buffer will be used +prb = True +# Number of frames stored in the buffer +buffer_size = min(total_frames, 1000000 // frame_skip) +buffer_scratch_dir = "/tmp/" + seed = 0 ############################################################################### -# **Note**: for fast rendering of the tutorial ``total_frames`` hyperparameter -# was set to a very low number. To get a reasonable performance, use a greater -# value e.g. 1000000. -# # Initialization -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# -------------- +# # To initialize the experiment, we first acquire the observation statistics, # then build the networks, wrap them in an exploration wrapper (following the # seminal DDPG paper, we used an Ornstein-Uhlenbeck process to add noise to the # sampled actions). 
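#
# For intuition, the Ornstein-Uhlenbeck process produces temporally correlated
# noise through a mean-reverting update. A standalone sketch (illustrative
# coefficients, not TorchRL's implementation) looks like this:
#
# .. code-block:: python
#
#    import torch
#
#    theta, mu, sigma, dt = 0.15, 0.0, 0.2, 1e-2  # illustrative coefficients
#    noise = torch.zeros(6)                       # e.g. one value per action dimension
#    for _ in range(5):
#        # drift back towards mu, plus a Gaussian diffusion term
#        noise = noise + theta * (mu - noise) * dt + sigma * dt**0.5 * torch.randn_like(noise)
#
# The exploration wrapper adds such a noise sample (annealed over time) to the
# deterministic action computed by the policy.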
-torch.manual_seed(0) -np.random.seed(0) -# get stats for normalization -stats = get_env_stats() +# Seeding +torch.manual_seed(seed) +np.random.seed(seed) + +############################################################################### +# Normalization stats +# ^^^^^^^^^^^^^^^^^^^ + +transform_state_dict = get_env_stats() + +############################################################################### +# Models: policy and q-value network +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Actor and qnet instantiation actor, qnet = make_ddpg_actor( - stats=stats, + transform_state_dict=transform_state_dict, device=device, ) if device == torch.device("cpu"): actor.share_memory() -# Target network + +############################################################################### +# We create a copy of the q-value network to be used as target network + qnet_target = deepcopy(qnet).requires_grad_(False) -# Exploration wrappers: +############################################################################### +# The policy is wrapped in a :class:`torchrl.modules.OrnsteinUhlenbeckProcessWrapper` +# exploration module: + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( actor, annealing_num_steps=annealing_frames, @@ -489,18 +612,36 @@ def make_replay_buffer(make_replay_buffer=3): if device == torch.device("cpu"): actor_model_explore.share_memory() -# Environment setting: +############################################################################### +# Parallel environment creation +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# We pass the stats computed earlier to normalize the output of our +# environment: + create_env_fn = parallel_env_constructor( - stats=stats, + transform_state_dict=transform_state_dict, ) ############################################################################### # Data collector -# ------------------------------ -# Creating the data collector is a crucial step in an RL experiment. TorchRL -# provides a couple of classes to collect data in parallel. Here we will use -# ``MultiaSyncDataCollector``, a data collector that will be executed in an -# async manner (i.e. data will be collected while the policy is being optimized). +# ^^^^^^^^^^^^^^ +# +# TorchRL provides specialized classes to help you collect data by executing +# the policy in the environment. These "data collectors" iteratively compute +# the action to be executed at a given time, then execute a step in the +# environment and reset it when required. +# Data collectors are designed to help developers have a tight control +# on the number of frames per batch of data, on the (a)sync nature of this +# collection and on the resources allocated to the data collection (e.g. GPU, +# number of workers etc). +# +# Here we will use +# :class:`torchrl.collectors.MultiaSyncDataCollector`, a data collector that +# will be executed in an async manner (i.e. data will be collected while +# the policy is being optimized). With the :class:`MultiaSyncDataCollector`, +# multiple workers are running rollouts separately. When a batch is asked, it +# is gathered from the first worker that can provide it. # # The parameters to specify are: # @@ -514,14 +655,24 @@ def make_replay_buffer(make_replay_buffer=3): # # - the number of frames in each batch collected, # - the number of random steps executed independently from the policy, -# - the devices used for policy execution, and -# - data transmission. +# - the devices used for policy execution +# - the devices used to store data before the data is passed to the main +# process. 
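#
# As an illustration (a minimal single-process sketch with arbitrary frame
# counts, reusing the ``make_env`` factory and the exploration policy defined
# above; the async collector used in this tutorial is built right below),
# iterating over a collector yields batches of transitions:
#
# .. code-block:: python
#
#    from torchrl.collectors import SyncDataCollector
#
#    demo_collector = SyncDataCollector(
#        create_env_fn=make_env,
#        policy=actor_model_explore,
#        frames_per_batch=50,
#        total_frames=200,
#    )
#    for batch in demo_collector:  # each batch is a TensorDict of transitions
#        print(batch.shape)
#    demo_collector.shutdown()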
# -# The ``MultiStep`` object passed as postproc makes it so that the rewards -# of the n upcoming steps are added (with some discount factor) and the next -# observation is changed to be the n-step forward observation. +# Collectors also accept post-processing hooks. +# For instance, the :class:`torchrl.data.postprocs.MultiStep` class passed as +# ``postproc`` makes it so that the rewards of the ``n`` upcoming steps are +# summed (with some discount factor) and the next observation is changed to +# be the n-step forward observation. One could pass other transforms too: +# using :class:`tensordict.nn.TensorDictModule` and +# :class:`tensordict.nn.TensorDictSequential` we can seamlessly append a +# wide range of transforms to our collector. + +if n_steps_forward > 0: + multistep = MultiStep(n_steps_max=n_steps_forward, gamma=gamma) +else: + multistep = None -# Batch collector: collector = MultiaSyncDataCollector( create_env_fn=[create_env_fn, create_env_fn], policy=actor_model_explore, @@ -530,33 +681,37 @@ def make_replay_buffer(make_replay_buffer=3): frames_per_batch=frames_per_batch, init_random_frames=init_random_frames, reset_at_each_iter=False, - postproc=MultiStep(n_steps_max=n_steps_forward, gamma=gamma) - if n_steps_forward > 0 - else None, + postproc=multistep, split_trajs=True, devices=[device, device], # device for execution storing_devices=[device, device], # device where data will be stored and passed - seed=None, pin_memory=False, update_at_each_batch=False, exploration_mode="random", ) + collector.set_seed(seed) ############################################################################### -# We can now create the replay buffer as part of the initialization. +# Replay buffer +# ^^^^^^^^^^^^^ +# -# Replay buffer: -replay_buffer = make_replay_buffer() +replay_buffer = make_replay_buffer(buffer_size, prefetch=3) + +############################################################################### +# Recorder +# ^^^^^^^^ -# Trajectory recorder -recorder = make_recorder(actor_model_explore, stats) +recorder = make_recorder(actor_model_explore, transform_state_dict) ############################################################################### +# Optimizer +# ^^^^^^^^^ +# # Finally, we will use the Adam optimizer for the policy and value network, # with the same learning rate for both. -# Optimizers optimizer_actor = optim.Adam(actor.parameters(), lr=lr, weight_decay=weight_decay) optimizer_qnet = optim.Adam(qnet.parameters(), lr=lr, weight_decay=weight_decay) total_collection_steps = total_frames // frames_per_batch @@ -569,43 +724,50 @@ def make_replay_buffer(make_replay_buffer=3): ) ############################################################################### -# Time to train the policy! -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Some notes about the following cell: -# -# - ``hold_out_net`` is a TorchRL context manager that temporarily sets -# ``requires_grad`` to False for a set of network parameters. This is used to -# prevent ``backward`` to write gradients on parameters that need not to be -# differentiated given the loss at hand. -# - The value network is designed using the ``ValueOperator`` TensorDictModule -# subclass. 
This class will write a ``"state_action_value"`` if one of its -# ``in_keys`` is named "action", otherwise it will assume that only the +# Time to train the policy +# ------------------------ +# +# Some notes about the following training loop: +# +# - :func:`torchrl.objectives.utils.hold_out_net` is a TorchRL context manager +# that temporarily sets :func:`torch.Tensor.requires_grad_()` to False for +# a designated set of network parameters. This is used to +# prevent :func:`torch.Tensor.backward()`` from writing gradients on +# parameters that need not to be differentiated given the loss at hand. +# - The value network is designed using the +# :class:`torchrl.modules.ValueOperator` subclass from +# :class:`tensordict.nn.TensorDictModule` class. As explained earlier, +# this class will write a ``"state_action_value"`` entry if one of its +# ``in_keys`` is named ``"action"``, otherwise it will assume that only the # state-value is returned and the output key will simply be ``"state_value"``. # In the case of DDPG, the value if of the state-action pair, -# hence the first name is used. -# - The ``step_mdp`` helper function returns a new TensorDict that essentially -# does the ``obs = next_obs`` step. In other words, it will return a new -# tensordict where the values that are related to the next state (next -# observations of various type) are selected and written as if they were -# current. This makes it possible to pass this new tensordict to the policy or +# hence the ``"state_action_value"`` will be used. +# - The :func:`torchrl.envs.utils.step_mdp(tensordict)` helper function is the +# equivalent of the ``obs = next_obs`` command found in multiple RL +# algorithms. It will return a new :class:`tensordict.TensorDict` instance +# that contains all the data that will need to be used in the next iteration. +# This makes it possible to pass this new tensordict to the policy or # value network. # - When using prioritized replay buffer, a priority key is added to the # sampled tensordict (named ``"td_error"`` by default). Then, this -# TensorDict will be fed back to the replay buffer using the ``update_priority`` +# TensorDict will be fed back to the replay buffer using the +# :func:`torchrl.data.replay_buffers.TensorDictReplayBuffer.update_tensordict_priority` # method. Under the hood, this method will read the index present in the # TensorDict as well as the priority value, and update its list of priorities # at these indices. # - TorchRL provides optimized versions of the loss functions (such as this one) # where one only needs to pass a sampled tensordict and obtains a dictionary -# of losses and metadata in return (see ``torchrl.objectives`` for more +# of losses and metadata in return (see :mod:`torchrl.objectives` for more # context). Here we write the full loss function in the optimization loop -# for transparency. Similarly, the target network updates are written -# explicitely but TorchRL provides a couple of dedicated classes for this -# (see ``torchrl.objectives.SoftUpdate`` and ``torchrl.objectives.HardUpdate``). -# - After each collection of data, we call ``collector.update_policy_weights_()``, +# for transparency. +# Similarly, the target network updates are written explicitly but +# TorchRL provides a couple of dedicated classes for this +# (see :class:`torchrl.objectives.SoftUpdate` and +# :class:`torchrl.objectives.HardUpdate`). 
+# - After each collection of data, we call :func:`collector.update_policy_weights_()`, # which will update the policy network weights on the data collector. If the # code is executed on cpu or with a single cuda device, this part can be -# ommited. If the collector is executed on another device, then its weights +# omitted. If the collector is executed on another device, then its weights # must be synced with those on the main, training process and this method # should be incorporated in the training loop (ideally early in the loop in # async settings, and at the end of it in sync settings). @@ -643,7 +805,7 @@ def make_replay_buffer(make_replay_buffer=3): # optimization steps if collected_frames >= init_random_frames: - for _ in range(optim_steps_per_batch): + for _ in range(update_to_data): # sample from replay buffer sampled_tensordict = replay_buffer.sample(batch_size).clone() @@ -669,7 +831,9 @@ def make_replay_buffer(make_replay_buffer=3): optimizer_qnet.step() optimizer_qnet.zero_grad() - # compute loss for actor and backprop: the actor must maximise the state-action value, hence the loss is the neg value of this. + # compute loss for actor and backprop: + # the actor must maximise the state-action value, hence the loss + # is the neg value of this. sampled_tensordict_actor = sampled_tensordict.select(*actor.in_keys) with hold_out_net(qnet): qnet(actor(sampled_tensordict_actor)) @@ -694,7 +858,7 @@ def make_replay_buffer(make_replay_buffer=3): ) td_record = recorder(None) if td_record is not None: - rewards_eval.append((i, td_record["r_evaluation"])) + rewards_eval.append((i, td_record["r_evaluation"].item())) if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), reward eval: reward: {rewards_eval[-1][1]: 4.4f}" @@ -707,15 +871,17 @@ def make_replay_buffer(make_replay_buffer=3): scheduler2.step() collector.shutdown() +del collector ############################################################################### # Experiment results -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ------------------ +# # We make a simple plot of the average rewards during training. We can observe # that our policy learned quite well to solve the task. # # **Note**: As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 1000000. +# use a greater value for ``total_frames`` e.g. 1M. plt.figure() plt.plot(*zip(*rewards), label="training") @@ -727,7 +893,8 @@ def make_replay_buffer(make_replay_buffer=3): ############################################################################### # Sampling trajectories and using TD(lambda) -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ------------------------------------------ +# # TD(lambda) is known to be less biased than the regular TD-error we used in # the previous example. To use it, however, we need to sample trajectories and # not single transitions. @@ -735,37 +902,38 @@ def make_replay_buffer(make_replay_buffer=3): # We modify the previous example to make this possible. # # The first modification consists in building a replay buffer that stores -# trajectories (and not transitions). We'll collect trajectories of (at most) -# 250 steps (note that the total trajectory length is actually 1000, but we -# collect batches of 500 transitions obtained over 2 environments running in +# trajectories (and not transitions). 
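#
# The difference only lies in the batch size of the stored tensordicts, as the
# small illustration below shows (the shapes are made up, but ``[2, 250]``
# matches the collector output used here: 2 environments times 250 steps):
#
# .. code-block:: python
#
#    import torch
#    from tensordict import TensorDict
#
#    data = TensorDict({"reward": torch.zeros(2, 250, 1)}, batch_size=[2, 250])
#    print(data.view(-1).batch_size)  # torch.Size([500]): one entry per transition
#    print(data.batch_size)           # torch.Size([2, 250]): one entry per trajectory slice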
+# +# Specifically, we'll collect trajectories of (at most) +# 250 steps (note that the total trajectory length is actually 1000 frames, but +# we collect batches of 500 transitions obtained over 2 environments running in # parallel, hence only 250 steps per trajectory are collected at any given -# time). Hence, we'll devide our replay buffer size by 250: +# time). Hence, we'll divide our replay buffer size by 250: buffer_size = 100000 // frame_skip // 250 print("the new buffer size is", buffer_size) batch_size_traj = max(4, batch_size // 250) print("the new batch size for trajectories is", batch_size_traj) -############################################################################### - n_steps_forward = 0 # disable multi-step for simplicity ############################################################################### # The following code is identical to the initialization we made earlier: -torch.manual_seed(0) -np.random.seed(0) +torch.manual_seed(seed) +np.random.seed(seed) # get stats for normalization -stats = get_env_stats() +transform_state_dict = get_env_stats() # Actor and qnet instantiation actor, qnet = make_ddpg_actor( - stats=stats, + transform_state_dict=transform_state_dict, device=device, ) if device == torch.device("cpu"): actor.share_memory() + # Target network qnet_target = deepcopy(qnet).requires_grad_(False) @@ -779,9 +947,13 @@ def make_replay_buffer(make_replay_buffer=3): # Environment setting: create_env_fn = parallel_env_constructor( - stats=stats, + transform_state_dict=transform_state_dict, ) # Batch collector: +if n_steps_forward > 0: + multistep = MultiStep(n_steps_max=n_steps_forward, gamma=gamma) +else: + multistep = None collector = MultiaSyncDataCollector( create_env_fn=[create_env_fn, create_env_fn], policy=actor_model_explore, @@ -790,10 +962,8 @@ def make_replay_buffer(make_replay_buffer=3): frames_per_batch=frames_per_batch, init_random_frames=init_random_frames, reset_at_each_iter=False, - postproc=MultiStep(n_steps_max=n_steps_forward, gamma=gamma) - if n_steps_forward > 0 - else None, - split_trajs=True, + postproc=multistep, + split_trajs=False, devices=[device, device], # device for execution storing_devices=[device, device], # device where data will be stored and passed seed=None, @@ -804,10 +974,10 @@ def make_replay_buffer(make_replay_buffer=3): collector.set_seed(seed) # Replay buffer: -replay_buffer = make_replay_buffer(0) +replay_buffer = make_replay_buffer(buffer_size, prefetch=0) # trajectory recorder -recorder = make_recorder(actor_model_explore, stats) +recorder = make_recorder(actor_model_explore, transform_state_dict) # Optimizers optimizer_actor = optim.Adam(actor.parameters(), lr=lr, weight_decay=weight_decay) @@ -822,9 +992,10 @@ def make_replay_buffer(make_replay_buffer=3): ) ############################################################################### -# The training loop needs to be modified. First, whereas before extending the -# replay buffer we used to flatten the collected data, this won't be the case -# anymore. To understand why, let's check the output shape of the data collector: +# The training loop needs to be slightly adapted. +# First, whereas before extending the replay buffer we used to flatten the +# collected data, this won't be the case anymore. 
To understand why, let's +# check the output shape of the data collector: for data in collector: print(data.shape) @@ -834,7 +1005,8 @@ def make_replay_buffer(make_replay_buffer=3): # We see that our data has shape ``[2, 250]`` as expected: 2 envs, each # returning 250 frames. # -# Let's import the td_lambda function +# Let's import the td_lambda function: +# from torchrl.objectives.value.functional import vec_td_lambda_advantage_estimate @@ -848,13 +1020,6 @@ def make_replay_buffer(make_replay_buffer=3): # to compute gradients. This ensures that do not have batches that are # 'too big' but still compute an accurate return. # -# Note that when storing tensordicts the replay buffer, we must change their -# batch size: indeed, we will be storing an "index" (and possibly an -# priority) key in the stored tensordicts that will not have a time dimension. -# Because of this, when sampling from the replay buffer, we remove the keys -# that do not have a time dimension, change the batch size to -# ``torch.Size([batch, time])``, compute our loss and then revert the -# batch size to ``torch.Size([batch])``. rewards = [] rewards_eval = [] @@ -875,25 +1040,21 @@ def make_replay_buffer(make_replay_buffer=3): if r0 is None: r0 = tensordict["reward"].mean().item() - # pbar.update(tensordict.numel()) # extend the replay buffer with the new data - tensordict.batch_size = tensordict.batch_size[ - :1 - ] # this is necessary for prioritized replay buffers: we will assign one priority value to each element, hence the batch_size must comply with the number of priority values current_frames = tensordict.numel() - collected_frames += tensordict["collector", "mask"].sum() + collected_frames += current_frames replay_buffer.extend(tensordict.cpu()) # optimization steps if collected_frames >= init_random_frames: - for _ in range(optim_steps_per_batch): + for _ in range(update_to_data): # sample from replay buffer sampled_tensordict = replay_buffer.sample(batch_size_traj) - # reset the batch size temporarily, and exclude index whose shape is incompatible with the new size + # reset the batch size temporarily, and exclude index + # whose shape is incompatible with the new size index = sampled_tensordict.get("index") sampled_tensordict.exclude("index", inplace=True) - sampled_tensordict.batch_size = [batch_size_traj, 250] # compute loss for qnet and backprop with hold_out_net(actor): @@ -905,7 +1066,8 @@ def make_replay_buffer(make_replay_buffer=3): next_value = next_tensordict["state_action_value"] assert not next_value.requires_grad - # This is the crucial bit: we'll compute the TD(lambda) instead of a simple single step estimate + # This is the crucial part: we'll compute the TD(lambda) + # instead of a simple single step estimate done = sampled_tensordict["done"] reward = sampled_tensordict["reward"] value = qnet(sampled_tensordict.view(-1)).view(sampled_tensordict.shape)[ @@ -956,7 +1118,7 @@ def make_replay_buffer(make_replay_buffer=3): ) td_record = recorder(None) if td_record is not None: - rewards_eval.append((i, td_record["r_evaluation"])) + rewards_eval.append((i, td_record["r_evaluation"].item())) # if len(rewards_eval): # pbar.set_description(f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), reward eval: reward: {rewards_eval[-1][1]: 4.4f}") @@ -967,6 +1129,8 @@ def make_replay_buffer(make_replay_buffer=3): scheduler2.step() collector.shutdown() +del create_env_fn +del collector ############################################################################### # We can observe that using TD(lambda) made our 
results considerably more @@ -983,3 +1147,9 @@ def make_replay_buffer(make_replay_buffer=3): plt.ylabel("reward") plt.tight_layout() plt.title("TD-labmda DDPG results") + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 6b23a8462f6..f422ef2f273 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -227,7 +227,7 @@ def make_env(parallel=False, m=0, s=1): # in the input ``TensorDict``. -def make_model(): +def make_model(dummy_env): cnn_kwargs = { "num_cells": [32, 64, 64], "kernel_sizes": [6, 4, 3], @@ -243,7 +243,6 @@ def make_model(): 64, 64, ], - # "out_features": dummy_env.action_spec.shape[-1], "activation_class": nn.ELU, } net = DuelingCnnDQNet( @@ -293,7 +292,7 @@ def make_model(): actor_explore, params, params_target, -) = make_model() +) = make_model(dummy_env) params_flat = params.flatten_keys(".") params_target_flat = params_target.flatten_keys(".") @@ -379,7 +378,7 @@ def make_model(): if data["done"].any(): done = data["done"].squeeze(-1) - traj_lengths.append(data["step_count"][done].float().mean().item()) + traj_lengths.append(data["collector", "step_count"][done].float().mean().item()) # check that we have enough data to start training if sum(frames) > init_random_frames: @@ -484,6 +483,10 @@ def make_model(): # update policy weights data_collector.update_policy_weights_() +print("shutting down") +data_collector.shutdown() +del data_collector + if is_notebook(): display.clear_output(wait=True) display.display(plt.gcf()) @@ -547,7 +550,7 @@ def make_model(): actor_explore, params, params_target, -) = make_model() +) = make_model(dummy_env) params_flat = params.flatten_keys(".") params_target_flat = params_target.flatten_keys(".") @@ -614,7 +617,7 @@ def make_model(): if data["done"].any(): done = data["done"].squeeze(-1) - traj_lengths.append(data["step_count"][done].float().mean().item()) + traj_lengths.append(data["collector", "step_count"][done].float().mean().item()) if sum(frames) > init_random_frames: for _ in range(n_optim): @@ -726,6 +729,10 @@ def make_model(): # update policy weights data_collector.update_policy_weights_() +print("shutting down") +data_collector.shutdown() +del data_collector + if is_notebook(): display.clear_output(wait=True) display.display(plt.gcf()) @@ -851,7 +858,8 @@ def make_model(): dummy_env.transform.insert(0, CatTensors(["pixels"], "pixels_save", del_keys=False)) eval_rollout = dummy_env.rollout(max_steps=10000, policy=actor, auto_reset=True).cpu() -eval_rollout +print(eval_rollout) +del dummy_env ############################################################################### @@ -885,3 +893,9 @@ def make_model(): # - More fancy exploration techniques, such as NoisyLinear layers and such # (check ``torchrl.modules.NoisyLinear``, which is fully compatible with the # ``MLP`` class used in our Dueling DQN). 
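###############################################################################
# As a quick sketch of that last suggestion (an illustrative addition, not part
# of the original tutorial): assuming ``torchrl.modules.MLP``'s ``layer_class``
# argument forwards the chosen layer type to every linear layer, the regular
# ``nn.Linear`` layers can be swapped for ``NoisyLinear`` ones with a single
# keyword. The feature and action sizes below are hypothetical.

import torch
from torchrl.modules import MLP, NoisyLinear

noisy_head = MLP(
    in_features=64,  # assumed size of the features fed to the head
    out_features=6,  # assumed number of discrete actions
    num_cells=[64, 64],
    layer_class=NoisyLinear,  # every linear layer becomes a NoisyLinear
)
print(noisy_head(torch.randn(8, 64)).shape)  # torch.Size([8, 6])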
+ +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 9896d616fbc..57dd72544e7 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -72,7 +72,8 @@ ############################################################################### # Policy -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^ +# # We will design a policy where a backbone reads the "observation" key. # Then specific sub-components will ready the "observation_stand" and # "observation_walk" keys of the stacked tensordicts, if they are present, @@ -119,7 +120,8 @@ ############################################################################### # Executing diverse tasks in parallel -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# # We can parallelize the operations if the common keys-value pairs share the # same specs (in particular their shape and dtype must match: you can't do the # following if the observation shapes are different but are pointed to by the @@ -174,7 +176,6 @@ def env2_maker(): print(tdreset[0]) print(tdreset[1]) # should be different but all have an "action" key -############################################################################### env.step(tdreset) # computes actions and execute steps in parallel print(tdreset) @@ -183,7 +184,7 @@ def env2_maker(): ############################################################################### # Rollout -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^ td_rollout = env.rollout(100, policy=seq, return_contiguous=False) @@ -194,3 +195,12 @@ def env2_maker(): ############################################################################### td_rollout[0] # tensordict of the first env: the stand obs is present + +env.close() +del env + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py new file mode 100644 index 00000000000..fccdcf8888a --- /dev/null +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -0,0 +1,710 @@ +# -*- coding: utf-8 -*- +""" +Writing your environment with TorchRL +===================================== +**Author**: `Vincent Moens `_ + +Creating an environment (a simulator or a interface to a physical control system) +is an integrative part of reinforcement learning and control engineering. + +TorchRL provides a set of tools to do this in multiple contexts. +This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` code a pendulum +simulator from the ground up. +It is freely inspired by the Pendulum-v1 implementation from `OpenAI-Gym/Farama-Gymnasium +control library `__. + +.. figure:: /_static/img/pendulum.gif + :alt: Pendulum + + Simple Pendulum + +Key learnings: + +- How to design an environment in TorchRL: + + - Specs (input, observation and reward); + - Methods: seeding, reset and step; +- Transforming your environment inputs and outputs; +- How to use :class:`tensordict.TensorDict` to carry arbitrary data structures + from sep to step. + +We will touch three crucial components of TorchRL: + +* `environments `__ +* `transforms `__ +* `models (policy and value function) `__ + +""" +###################################################################### +# To give a sense of what can be achieved with TorchRL's environments, we will +# be designing a stateless environment. 
While stateful environments keep track of +# the latest physical state encountered and rely on this to simulate the state-to-state +# transition, stateless environments expect the current state to be provided to +# them at each step, along with the action undertaken. +# +# Modelling stateless environments gives users full control over the input and +# outputs of the simulator: one can reset an experiment at any stage. It also +# assumes that we have some control over a task, which may not always be the case +# (solving a problem where we cannot control the current state is more challenging +# but has a much wider set of applications). +# +# Another advantage of stateless environments is that most of the time they allow +# for batched execution of transition simulations. If the backend and the +# implementation allow it, an algebraic operation can be executed seamlessly on +# scalars, vectors or tensors. This tutorial gives such examples. +# +# This tutorial will be structured as follows: +# +# * We will first get acquainted with the environment properties: +# its shape (``batch_size``), its methods (mainly `step`, `reset` and `set_seed`) +# and finally its specs. +# * After having coded our simulator, we will demonstrate how it can be used +# during training with transforms. +# * We will explore surprising new avenues that follow from the TorchRL's API, +# including: the possibility of transforming inputs, the vectorized execution +# of the simulation and the possibility of backpropagating through the +# simulation. +# * Finally, will train a simple policy to solve the system we implemented. +# +from collections import defaultdict +from typing import Optional + +import numpy as np +import torch +import tqdm +from tensordict.nn import TensorDictModule +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.envs import ( + CatTensors, + Compose, + EnvBase, + TransformedEnv, + UnsqueezeTransform, +) +from torchrl.envs.utils import check_env_specs, step_mdp + +DEFAULT_X = np.pi +DEFAULT_Y = 1.0 + +###################################################################### +# There are four things one must take care of when designing a new environment +# class: ``_reset``, which codes for the resetting of the simulator at a (potentially +# random) initial state, ``_step`` which codes for the state transition dynamic, +# ``_set_seed`` which implements the seeding mechanism and finally the +# environment specs. +# +# :func:`_step` +# ~~~~~~~~~~~~~ +# +# The step method is the first thing to consider, as it will encode +# the simulation that is of interest to us. In TorchRL, the :class:`torchrl.envs.EnvBase` +# class has a :func:`EnvBase.step(tensordict)` method that receives a :class:`tensordict.TensorDict` +# instance with an ``"action"`` entry indicating what action is to be taken. +# To facilitate the reading and writing from that tensordict and to make sure +# that the keys are consistent with what's expected from the library, the +# simulation part has been delegated to a private abstract method :func:`_step` +# which reads input data from a tensordict, and writes a new tensordict with the +# output data. +# The :func:`_step` method should: +# +# 1. read the input keys (such as ``"action"``) and execute the simulation +# based on these; +# 2. retrieve observations, done state and reward; +# 3. 
write the set of observation value along with the reward and done state +# at the corresponding entries in a new :class:`TensorDict`. +# +# Next, the `step` method will rearrange this output and move the key-pair +# values of the observation in a new entry named ``"next"`` and leave the ``"reward"`` +# and ``"done"`` state at the root level. It will also run some sanity checks +# on the shapes of the tensordict content. +# +# Typically, this will look like +# +# .. code-block:: +# +# >>> print(tensordict) +# TensorDict(TODO) +# >>> env.step(tensordict) +# >>> print(tensordict) +# TensorDict(TODO) +# +# In the Pendulum example, our :func:`_step` method will read the relevant entries +# from the input tensordict and compute the position and velocity of the +# pendulum after the force encoded by the ``"action"`` key has been applied +# onto it. We compute the new angular position of the pendulum ``new_th`` as the result +# of the previous position ``th`` plus the new velocity ``new_thdot`` over a +# time interval ``dt``. Additionally, we pass the ``sin`` and ``cos`` of the +# angle to facilitate learning. +# +# Since our goal is to turn the pendulum up and maintain it still in that +# position, our ``cost`` (negative reward) function is lower for positions +# close to the target and low speeds. +# Indeed, we want to punish positions that are far from being "upward" +# and/or speeds that are far from 0. +# +# In our example, :func:`_step` is encoded as a static method since our +# environment is stateless. In stateful settings, the ``self`` argument is +# needed as the state needs to be read from the environment. +# + + +def _step(tensordict): + th, thdot = tensordict["th"], tensordict["thdot"] # th := theta + + g_force = tensordict["params", "g"] + mass = tensordict["params", "m"] + length = tensordict["params", "l"] + dt = tensordict["params", "dt"] + u = tensordict["action"].squeeze(-1) + u = u.clamp(-tensordict["params", "max_torque"], tensordict["params", "max_torque"]) + costs = angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2) + + new_thdot = ( + thdot + + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u) * dt + ) + new_thdot = new_thdot.clamp( + -tensordict["params", "max_speed"], tensordict["params", "max_speed"] + ) + new_th = th + new_thdot * dt + reward = -costs.view(*tensordict.shape, 1) + done = torch.zeros_like(reward, dtype=torch.bool) + out = TensorDict( + { + "th": new_th, + "sin": new_th.sin(), + "cos": new_th.cos(), + "thdot": new_thdot, + "params": tensordict["params"], + "reward": reward, + "done": done, + }, + tensordict.shape, + ) + return out + + +def angle_normalize(x): + return ((x + torch.pi) % (2 * torch.pi)) - torch.pi + + +###################################################################### +# :func:`_reset` +# ~~~~~~~~~~~~~ +# +# The second method we need to care about is the :func:`_reset` method. Like +# :func:`_step`, it should write the observation entries and possibly a done state +# in the tensordict it outputs. In some contexts, it is required that the `_reset` +# method receives a command from the function that called it (e.g. in multi-agent +# settings we may want to indicate which agent needs to be reset). This is +# why the :func:`_reset` method also expects a tensordict as input, albeit +# it may perfectly be empty. 
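#
# As a small usage sketch (hypothetical, relying on the ``PendulumEnv`` class
# and its ``gen_params`` helper assembled further down in this tutorial), the
# public :func:`EnvBase.reset` can be called with or without an input
# tensordict:
#
# .. code-block::
#
#    >>> env = PendulumEnv()
#    >>> td = env.reset()                  # empty input: default parameters are generated
#    >>> td = env.reset(env.gen_params())  # parameters supplied by the caller
#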
+# +# The parent :class:`EnvBase.reset` does some simple checks like the +# :class:`EnvBase.step` does, such as making sure that a ``"done"`` state +# is returned in the output tensordict and that the shapes match what is +# expected from the specs. +# +# For us, the only important thing to consider is whether +# :class:`EnvBase._reset` contains all the expected observations. Once more, +# since we are working with a stateless environment, we pass the configuration +# of the pendulum in a ``"params"`` nested tensordict. +# Comparing the output of :class:`EnvBase._reset` with :class:`EnvBase._step` +# +# We do not pass a done state as this is not mandatory for :func:`_reset` and +# our environment is non-terminating. +# + + +def _reset(self, tensordict): + if tensordict is None or tensordict.is_empty(): + # if no tensordict is passed, we generate a single set of hyperparameters + # Otherwise, we assume that the input tensordict contains all the relevant + # parameters to get started. + tensordict = self.gen_params(batch_size=self.batch_size) + + high_th = torch.tensor(DEFAULT_X, device=self.device) + high_thdot = torch.tensor(DEFAULT_Y, device=self.device) + low_th = -high_th + low_thdot = -high_thdot + + # for non batch-locked envs, the input tensordict shape dictates the number + # of simulators run simultaneously. In other contexts, the initial + # random state's shape will depend upon the environment batch-size instead. + th = ( + torch.rand(tensordict.shape, generator=self.rng, device=self.device) + * (high_th - low_th) + + low_th + ) + thdot = ( + torch.rand(tensordict.shape, generator=self.rng, device=self.device) + * (high_thdot - low_thdot) + + low_thdot + ) + out = TensorDict( + { + "th": th, + "sin": th.sin(), + "cos": th.cos(), + "thdot": thdot, + "params": tensordict["params"], + }, + batch_size=tensordict.shape, + ) + return out + + +###################################################################### +# Specs +# ~~~~~ +# +# The specs define the input and output domain of the environment. +# It is important that the specs accurately define the tensors that will be +# received at runtime, as they are often used to carry information about +# environments in multiprocessing and distributed settings. +# There four specs that we must code in our environment: +# +# * :obj:`EnvBase.observation_spec`: This will be a :class:`torchrl.data.CompositeSpec` +# instance where each key is an observation. +# * :obj:`EnvBase.action_spec`: It can be any type of spec, but it is required that it +# corresponds to the ``"action"`` entry in the input tensordict. +# * :obj:`EnvBase.input_spec`: contains all the input entries, +# including the :obj:`EnvBase.action_spec` (which is just a pointer to +# :obj:`EnvBase.input_spec['action_spec']`. As for :obj:`EnvBase.ObservationSpec`, +# it is expected that this spec is of type :obj:`torchrl.data.CompositeSpec`. +# to accommodate environments where multiple inputs are expected. +# * :obj:`EnvBase.reward_spec`: the reward spec have the particularity of +# having a singleton trailing dimension if the environment has an empty +# batch size. The reason is that we often pass observations in torch models +# that estimate a value estimate with non-empty shape: +# +# .. code-block:: +# +# >>> next_value = reward + (1 - done) * fun(observation) +# +# Working with *unsqueezed* rewards allows us to build algorithms that are +# not polluted with squeezing and unsqueezing operations. 
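#
# As a quick illustration (a hypothetical snippet, not taken from the original
# tutorial), every :class:`torchrl.data.TensorSpec` can generate random samples
# from its domain and validate data against it:
#
# .. code-block::
#
#    >>> spec = BoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=(1,))
#    >>> action = spec.rand()  # a random tensor of shape (1,) within the bounds
#    >>> spec.is_in(action)    # True
#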
+# +# TorchRL offers multiple :class:`torchrl.data.TensorSpec` +# `subclasses `_ to +# encode the environment's input and output characteristics. +# +# *Specs shape*: The environment specs leading dimensions must match the +# environment batch-size. This is done to enforce that every component of an +# environment (including its transforms) have an accurate representation of +# the expected input and output shapes. This is something that should be +# accurately coded in stateful settings. +# +# For non batch-locked environments such as the one in our example (see below), +# this is irrelevant as the environment batch-size will most likely be empty. +# + + +def _make_spec(self, td_params): + self.observation_spec = CompositeSpec( + sin=BoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=(), dtype=torch.float32), + cos=BoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=(), dtype=torch.float32), + th=BoundedTensorSpec( + minimum=-torch.pi, + maximum=torch.pi, + shape=(), + dtype=torch.float32, + ), + thdot=BoundedTensorSpec( + minimum=-td_params["params", "max_speed"], + maximum=td_params["params", "max_speed"], + shape=(), + dtype=torch.float32, + ), + # we need to add the "params" to the observation specs, as we want + # to pass it at each step during a rollout + params=make_composite_from_td(td_params["params"]), + shape=(), + ) + # since the environment is stateless, we expect the previous output as input + self.input_spec = self.observation_spec.clone() + # action-spec will be automatically wrapped in input_spec, but the convenient + # self.action_spec = spec is supported + self.action_spec = BoundedTensorSpec( + minimum=-td_params["params", "max_torque"], + maximum=td_params["params", "max_torque"], + shape=(1,), + dtype=torch.float32, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + + +def make_composite_from_td(td): + # custom funtion to convert a tensordict in a similar spec structure + # of unbounded values. + composite = CompositeSpec( + { + key: make_composite_from_td(tensor) + if isinstance(tensor, TensorDictBase) + else UnboundedContinuousTensorSpec( + dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + ) + for key, tensor in td.items() + }, + shape=td.shape, + ) + return composite + + +###################################################################### +# Seeding +# ~~~~~~~ +# +# Seeding an environment is a commong operation when initializing an experiment. +# :func:`EnvBase._set_seed` only goal is to set the seed of the contained +# simulator. If possible, this operation should not call `reset()` or interact +# with the environment execution. The parent :func:`EnvBase.set_seed` method +# incorporates a mechanism that allows seeding multiple environments with a +# different pseudo-random and reproducible seed. +# + + +def _set_seed(self, seed: Optional[int]): + rng = torch.manual_seed(seed) + self.rng = rng + + +###################################################################### +# Wrapping things together: the :class:`torchrl.envs.EnvBase` class +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can finally put together the pieces and design our environment class. +# The specs initialization needs to be performed during the environment +# construction so we must take care of calling the :func:`_make_spec` method +# within :func:`PendulumEnv.__init__`. +# +# We add a class method :func:`PendulumEnv.gen_params` which deterministically +# generates a set of hyperparameters to be used during execution. 
+# +# We define the environment as non-`batch-locked` by turning the homonymous +# attribute to ``False``. This means that we will not enforce the input +# tensordict to have a batch-size that matches the one of the environment +# + + +class PendulumEnv(EnvBase): + metadata = { + "render_modes": ["human", "rgb_array"], + "render_fps": 30, + } + batch_locked = False + + @classmethod + def gen_params(cls, g=10.0, batch_size=None) -> TensorDictBase: + if batch_size is None: + batch_size = [] + td = TensorDict( + { + "params": TensorDict( + { + "max_speed": 8, + "max_torque": 2.0, + "dt": 0.05, + "g": g, + "m": 1.0, + "l": 1.0, + }, + [], + ) + }, + [], + ) + if batch_size: + td = td.expand(batch_size).contiguous() + return td + + def __init__(self, td_params=None, seed=None, device="cpu"): + if td_params is None: + td_params = self.gen_params() + + super().__init__(device=device, batch_size=[]) + self._make_spec(td_params) + if seed is None: + seed = torch.empty((), dtype=torch.int64).random_().item() + self.set_seed(seed) + + _make_spec = _make_spec + _reset = _reset + _step = staticmethod(_step) + _set_seed = _set_seed + + +###################################################################### +# Testing our environment +# ----------------------- +# +# TorchRL provides a simple function :func:`torchrl.envs.utils.check_env_specs` +# to check that a (transformed) environment has an input/output structure that +# matches the one dictated by its specs. +# Let us try it out: +# + +env = PendulumEnv() +check_env_specs(env) + +###################################################################### +# We can have a look at our specs to have a visual representation of the environment +# signature +# +print("observation_spec:", env.observation_spec) +print("input_spec:", env.input_spec) +print("reward_spec:", env.reward_spec) + +###################################################################### +# We can execute a couple of commands too to check that the output structure +# matches what is expected. We can run the :func:`env.rand_step` to generate +# an action randomly from the ``action_spec`` domain: +# + +td = env.reset() +print("reset tensordict", td) +td = env.rand_step(td) +print("random step tensordict", td) + +###################################################################### +# Transforming an environment +# --------------------------- +# +# Writing environment transforms for stateless simulators is slightly more +# complicated than for stateful ones: transforming an output entry that needs +# to be read at the following iteration requires to apply the inverse transform +# before calling :func:`env.step` at the next step. +# For instance, in the following transformed environment we unsqueeze the entries +# ``["sin", "cos", "thdot"]`` to be able to stack them along the last +# dimension. We also pass them as ``in_keys_inv`` to squeeze them back to their +# original shape once they are passed as input in the next iteration. +# +env = TransformedEnv( + env, + Compose( + # Unsqueezes the observations that we will concatenate + UnsqueezeTransform( + unsqueeze_dim=-1, + in_keys=["sin", "cos", "thdot"], + in_keys_inv=["sin", "cos", "thdot"], + ), + # Concatenates the observations onto an "observation" entry. + # del_keys=False ensures that we keep these values for the next + # iteration. 
+ CatTensors( + in_keys=["sin", "cos", "thdot"], out_key="observation", del_keys=False + ), + ), +) + +###################################################################### +# Executing a rollout +# ------------------- +# +# Executing a rollout is a succession of simple steps: +# +# * reset the environment +# * while some condition is not met: +# +# * compute an action given a policy +# * execute a step given this action +# * collect the data +# * make a MDP step +# +# * gather the data and return +# +# These operations have been convinently wrapped in the :func:`EnvBase.rollout` +# method, from which we provide a simplified version here below. + + +def simple_rollout(steps=100): + # preallocate: + data = TensorDict({}, [steps]) + # reset + _data = env.reset() + for i in range(steps): + _data["action"] = env.action_spec.rand() + _data = env.step(_data) + data[i] = _data + _data = step_mdp(_data, keep_other=True) + return data + + +print("data from rollout:", simple_rollout(100)) + +###################################################################### +# Batching computations +# --------------------- +# +# The last unexplored end of our tutorial is the ability that we have to +# batch computations in TorchRL. Because our environment does not +# make any assumptions regarding the input data shape, we can seamlessly +# execute it over batches of data. To do this, we just generate parameters +# with the desired shape. +# + +batch_size = 10 # number of environments to be executed in batch +td = env.reset(env.gen_params(batch_size=[batch_size])) +print("reset (batch size of 10)", td) +td = env.rand_step(td) +print("rand step (batch size of 10)", td) + +# executing a rollout with a batch of data requires us to reset the env +# out of the rollout function, since we need to define the batch_size +# dynamically and this is not supported by :func:`EnvBase.rollout`: +# + +rollout = env.rollout( + 3, + auto_reset=False, + tensordict=env.reset(env.gen_params(batch_size=[batch_size])), +) +print("rollout of len 3 (batch size of 10):", rollout) + + +###################################################################### +# Training a simple policy +# ------------------------ +# +# In this example, we will train a simple policy using the reward as a +# differentiable objective (i.e. a negative loss). +# We will take advantage of the fact that our dynamic system is fully +# differentiable to backpropagate through the trajectory return and adjust the +# weights of our policy to maximise this value directly. Of course, in many +# settings many of the assumptions we make do not hold, such as +# differentiability of the system and full access to the underlying mechanics. +# +# Still, this is a very simple example that showcases how a training loop can +# be coded with a custom environment in TorchRL. 
+# +# Let us first write the policy network: +# +torch.manual_seed(0) +env.set_seed(0) + +net = nn.Sequential( + nn.LazyLinear(64), + nn.Tanh(), + nn.LazyLinear(64), + nn.Tanh(), + nn.LazyLinear(64), + nn.Tanh(), + nn.LazyLinear(1), +) +policy = TensorDictModule( + net, + in_keys=["observation"], + out_keys=["action"], +) + +###################################################################### +# and our optimizer: +# + +optim = torch.optim.Adam(policy.parameters(), lr=2e-3) + +###################################################################### +# Finally, let us re-create our environment: +# + +env = TransformedEnv( + PendulumEnv(), + Compose( + UnsqueezeTransform( + unsqueeze_dim=-1, + in_keys=["sin", "cos", "thdot"], + in_keys_inv=["sin", "cos", "thdot"], + ), + CatTensors( + in_keys=["sin", "cos", "thdot"], out_key="observation", del_keys=False + ), + ), +) + +###################################################################### +# Training loop +# ~~~~~~~~~~~~~ +# +# We will successively: +# +# * generate a trajectory +# * sum the rewards +# * backpropagate through the graph defined by these operations +# * clip the gradient norm and make an optimization step +# * repeat +# +# At the end of the training loop, we should have a final reward close to 0 +# which demonstrates that the pendulum is upward and still as desired. +# +batch_size = 32 +pbar = tqdm.tqdm(range(20_000 // batch_size)) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000) +logs = defaultdict(list) + +for _ in pbar: + init_td = env.reset(env.gen_params(batch_size=[batch_size])) + rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False) + traj_return = rollout["reward"].mean() + (-traj_return).backward() + gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0) + optim.step() + optim.zero_grad() + pbar.set_description( + f"reward: {traj_return: 4.4f}, " + f"last reward: {rollout[..., -1]['reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}" + ) + logs["return"].append(traj_return.item()) + logs["last_reward"].append(rollout[..., -1]["reward"].mean().item()) + scheduler.step() + + +def plot(): + import matplotlib + from matplotlib import pyplot as plt + + is_ipython = "inline" in matplotlib.get_backend() + if is_ipython: + from IPython import display + + with plt.ion(): + plt.figure(figsize=(10, 5)) + plt.subplot(1, 2, 1) + plt.plot(logs["return"]) + plt.title("returns") + plt.xlabel("iteration") + plt.subplot(1, 2, 2) + plt.plot(logs["last_reward"]) + plt.title("last reward") + plt.xlabel("iteration") + if is_ipython: + display.display(plt.gcf()) + display.clear_output(wait=True) + plt.show() + + +plot() + +###################################################################### +# Conclusion +# ---------- +# +# In this tutorial, we have learned how to code a stateless environment from +# scratch. We touched the subjects of: +# +# * the four essential components that need to be taken care of when coding +# an environment (step, reset, seeding and building specs). We saw how these +# methods and classes interact with the :class:`tensordict.TensorDict` class; +# * how to test that an environment is properly coded using +# :func:`torchrl.envs.utils.check_env_specs`; +# * How to code transforms in the context of stateless environments; +# * How to train a policy on a fully differentiable simulator. 
+# + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index 73a69a60fce..9404b7abd43 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -37,10 +37,16 @@ # in the output tensordict. Our policy, consisting of a single layer MLP, will then read this vector and compute # the corresponding action. # -r3m = R3MTransform("resnet50", in_keys=["pixels"], download=True).to(device) +r3m = R3MTransform( + "resnet50", + in_keys=["pixels"], + download=True, +) env_transformed = TransformedEnv(base_env, r3m) net = nn.Sequential( - nn.LazyLinear(128), nn.Tanh(), nn.Linear(128, base_env.action_spec.shape[-1]) + nn.LazyLinear(128, device=device), + nn.Tanh(), + nn.Linear(128, base_env.action_spec.shape[-1], device=device), ) policy = Actor(net, in_keys=["r3m_vec"]) @@ -82,7 +88,7 @@ # from torchrl.data import LazyMemmapStorage, ReplayBuffer -storage = LazyMemmapStorage(1000) +storage = LazyMemmapStorage(1000, device=device) rb = ReplayBuffer(storage=storage, transform=r3m) ############################################################################## @@ -106,3 +112,9 @@ # batch = rb.sample(32) print("data after sampling:", batch) + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/tensordict_module.py b/tutorials/sphinx-tutorials/tensordict_module.py index 869e39f2a55..3d23662025d 100644 --- a/tutorials/sphinx-tutorials/tensordict_module.py +++ b/tutorials/sphinx-tutorials/tensordict_module.py @@ -833,3 +833,9 @@ def __init__( # ``TensorDictModule`` is marginal. # # Have fun with TensorDictModule! + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/tensordict_tutorial.py b/tutorials/sphinx-tutorials/tensordict_tutorial.py index ba37f5623e3..7a3d47e2e2e 100644 --- a/tutorials/sphinx-tutorials/tensordict_tutorial.py +++ b/tutorials/sphinx-tutorials/tensordict_tutorial.py @@ -494,3 +494,9 @@ def collate_dict_fn(dict_list): ############################################################################### # Have fun with TensorDict! # + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torch_envs.py b/tutorials/sphinx-tutorials/torch_envs.py index 5f4f70172b2..18897c810b9 100644 --- a/tutorials/sphinx-tutorials/torch_envs.py +++ b/tutorials/sphinx-tutorials/torch_envs.py @@ -96,7 +96,7 @@ # we can just generate a random action: -def policy(tensordict): +def policy(tensordict, env=env): tensordict.set("action", env.action_spec.rand()) return tensordict @@ -266,7 +266,7 @@ def policy(tensordict): ############################################################################### # Transforming envs -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^ # It is common to pre-process the output of an environment before having it # read by the policy or stored in a buffer. 
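#
# As a minimal sketch (with a hypothetical environment name, using transforms
# imported later in this tutorial), preprocessing pixel observations before
# they reach the policy could look like this:
#
# .. code-block::
#
#    >>> from torchrl.envs import Compose, Resize, ToTensorImage, TransformedEnv
#    >>> from torchrl.envs.libs.gym import GymEnv
#    >>> base_env = GymEnv("ALE/Pong-v5", from_pixels=True)
#    >>> env = TransformedEnv(base_env, Compose(ToTensorImage(), Resize(64, 64)))
#    >>> env.reset()["pixels"].shape  # now a float image of shape [3, 64, 64]
#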
# @@ -460,6 +460,8 @@ def env_make(): out_seed = parallel_env.set_seed(10) print(out_seed) +del parallel_env + ############################################################################### # Accessing environment attributes # --------------------------------- @@ -518,6 +520,7 @@ def env_make(): ############################################################################### parallel_env.close() +del parallel_env ############################################################################### # kwargs for parallel environments @@ -527,8 +530,7 @@ def env_make(): ############################################################################### -from torchrl.envs import Compose, ParallelEnv, Resize, ToTensorImage, TransformedEnv -from torchrl.envs.libs.gym import GymEnv +from torchrl.envs import ParallelEnv def env_make(env_name): @@ -552,6 +554,7 @@ def env_make(env_name): plt.subplot(122) plt.imshow(tensordict[1].get("pixels").permute(1, 2, 0).numpy()) parallel_env.close() +del parallel_env from matplotlib import pyplot as plt @@ -572,7 +575,6 @@ def env_make(env_name): ToTensorImage, TransformedEnv, ) -from torchrl.envs.libs.gym import GymEnv def env_make(env_name): @@ -598,6 +600,7 @@ def env_make(env_name): plt.subplot(122) plt.imshow(tensordict[1].get("pixels").permute(1, 2, 0).numpy()) parallel_env.close() +del parallel_env ############################################################################### # VecNorm @@ -667,3 +670,10 @@ def env_make(env_name): ) env.close() +del env + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 73ed384887c..8f539e88e47 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -119,10 +119,10 @@ # other dependencies (gym, torchvision, wandb / tensorboard) are optional. 
# # Data -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^ # # TensorDict -# ------------------------------ +# ---------- # sphinx_gallery_start_ignore import warnings @@ -369,7 +369,10 @@ ############################################################################### -env.action_spec +print(env.action_spec) + +env.close() +del env ############################################################################### # Modules @@ -570,7 +573,7 @@ ############################################################################### # Using Environments and Modules -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ from torchrl.envs.utils import step_mdp @@ -624,36 +627,37 @@ ############################################################################### -# helper torch.manual_seed(0) env.set_seed(0) tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps) tensordict_rollout -############################################################################### (tensordict_rollout == tensordicts_prealloc).all() from tensordict.nn import TensorDictModule +############################################################################### # Collectors -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^ from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector -############################################################################### - from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv +############################################################################### # EnvCreator makes sure that we can send a lambda function from process to process + parallel_env = ParallelEnv(3, EnvCreator(lambda: GymEnv("Pendulum-v1"))) create_env_fn = [parallel_env, parallel_env] actor_module = nn.Linear(3, 1) actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"]) +############################################################################### # Sync data collector + devices = ["cpu", "cpu"] collector = MultiSyncDataCollector( @@ -696,10 +700,12 @@ print(i) collector.shutdown() del collector +del create_env_fn +del parallel_env ############################################################################### # Objectives -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^ # TorchRL delivers meta-RL compatible loss functions # Disclaimer: This APi may change in the future @@ -746,7 +752,7 @@ def forward(self, obs, action): ############################################################################### # State of the Library -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^ # # TorchRL is currently an **alpha-release**: there may be bugs and there is no # guarantee about BC-breaking changes. We should be able to move to a beta-release @@ -767,6 +773,12 @@ def forward(self, obs, action): ############################################################################### # Installing the Library -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ^^^^^^^^^^^^^^^^^^^^^^ # # The library is on PyPI: *pip install torchrl* + +# sphinx_gallery_start_ignore +import time + +time.sleep(10) +# sphinx_gallery_end_ignore
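###############################################################################
# As a quick sanity check after installation (a hypothetical smoke test, not
# part of the original tutorial), importing the library and resetting a Gym
# environment should work out of the box:
#
# .. code-block::
#
#    >>> import torchrl
#    >>> from torchrl.envs.libs.gym import GymEnv
#    >>> print(GymEnv("Pendulum-v1").reset())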