From 7051238709bc1afbb87a6de9b6824b7bfc6c6f65 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 13 Nov 2024 18:45:59 +0000
Subject: [PATCH] [Doc] Adding recurrent policies to export tutorial

ghstack-source-id: 1f1af399b120db8bbb1789748641f44fd3b1bd5e
Pull Request resolved: https://github.com/pytorch/rl/pull/2559
---
 .github/workflows/docs.yml               |   2 +-
 docs/source/conf.py                      |   3 +
 test/test_tensordictmodules.py           |  65 +++++++++
 torchrl/_utils.py                        |  24 +++-
 torchrl/modules/tensordict_module/rnn.py | 108 +++++++++++++-
 tutorials/sphinx-tutorials/export.py     | 173 +++++++++++++++++------
 6 files changed, 328 insertions(+), 47 deletions(-)

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index e153641e775..97fae17a8d8 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -93,7 +93,7 @@ jobs:
         cd ./docs
         # timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
         # bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
-        PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build
+        PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
         cd ..
         cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 0a7781b8675..35f5e5c3882 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -49,6 +49,8 @@
 version = f"main ({torchrl.__version__})"
 release = "main"
 
+os.environ["TORCHRL_CONSOLE_STREAM"] = "stdout"
+
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
 #
@@ -95,6 +97,7 @@
     "abort_on_example_error": False,
     "only_warn_on_example_error": True,
     "show_memory": True,
+    "capture_repr": ("_repr_html_", "__repr__"),  # capture representations
 }
 
 napoleon_use_ivar = True
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 05e46ce0ecd..ec9322500b4 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -8,6 +8,8 @@
 
 import pytest
 import torch
+
+import torchrl.modules
 from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
 from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
 from torch import nn
@@ -743,6 +745,41 @@ def test_set_temporal_mode(self):
             lstm_module.parameters()
         )
 
+    def test_python_cudnn(self):
+        lstm_module = LSTMModule(
+            input_size=3,
+            hidden_size=12,
+            batch_first=True,
+            dropout=0,
+            num_layers=2,
+            in_keys=["observation", "hidden0", "hidden1"],
+            out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
+        ).set_recurrent_mode(True)
+        obs = torch.rand(10, 20, 3)
+
+        hidden0 = torch.rand(10, 20, 2, 12)
+        hidden1 = torch.rand(10, 20, 2, 12)
+
+        is_init = torch.zeros(10, 20, dtype=torch.bool)
+        assert isinstance(lstm_module.lstm, nn.LSTM)
+        outs_ref = lstm_module(
+            observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
+        )
+
+        lstm_module.make_python_based()
+        assert isinstance(lstm_module.lstm, torchrl.modules.LSTM)
+        outs_rl = lstm_module(
+            observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
+        )
+        torch.testing.assert_close(outs_ref, outs_rl)
+
+        lstm_module.make_cudnn_based()
+        assert isinstance(lstm_module.lstm, nn.LSTM)
+        outs_cudnn = lstm_module(
+            observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
+        )
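+        # after switching back to CuDNN, outputs should still match the reference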
+        torch.testing.assert_close(outs_ref, outs_cudnn)
+
     def test_noncontiguous(self):
         lstm_module = LSTMModule(
             input_size=3,
@@ -1088,6 +1125,34 @@ def test_set_temporal_mode(self):
             gru_module.parameters()
         )
 
+    def test_python_cudnn(self):
+        gru_module = GRUModule(
+            input_size=3,
+            hidden_size=12,
+            batch_first=True,
+            dropout=0,
+            num_layers=2,
+            in_keys=["observation", "hidden0"],
+            out_keys=["intermediate", ("next", "hidden0")],
+        ).set_recurrent_mode(True)
+        obs = torch.rand(10, 20, 3)
+
+        hidden0 = torch.rand(10, 20, 2, 12)
+
+        is_init = torch.zeros(10, 20, dtype=torch.bool)
+        assert isinstance(gru_module.gru, nn.GRU)
+        outs_ref = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
+
+        gru_module.make_python_based()
+        assert isinstance(gru_module.gru, torchrl.modules.GRU)
+        outs_rl = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
+        torch.testing.assert_close(outs_ref, outs_rl)
+
+        gru_module.make_cudnn_based()
+        assert isinstance(gru_module.gru, nn.GRU)
+        outs_cudnn = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
+        torch.testing.assert_close(outs_ref, outs_cudnn)
+
     def test_noncontiguous(self):
         gru_module = GRUModule(
             input_size=3,
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 31e00614fd9..f2ce0cf520e 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -40,7 +40,17 @@
 # Remove all attached handlers
 while logger.hasHandlers():
     logger.removeHandler(logger.handlers[0])
-console_handler = logging.StreamHandler()
+stream_handlers = {
+    "stdout": sys.stdout,
+    "stderr": sys.stderr,
+}
+TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
+if TORCHRL_CONSOLE_STREAM:
+    stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM]
+else:
+    stream_handler = None
+console_handler = logging.StreamHandler(stream=stream_handler)
+
 console_handler.setLevel(logging.INFO)
 formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
 console_handler.setFormatter(formatter)
@@ -86,9 +96,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         val[2] = N
 
     @staticmethod
-    def print(prefix=None):  # noqa: T202
+    def print(prefix=None) -> str:  # noqa: T202
+        """Prints the state of the timer.
+
+        Returns:
+            the string printed using the logger.
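+
+        Examples:
+            >>> # sketch of typical usage; "my_block" and "run" are made-up names
+            >>> import time
+            >>> with timeit("my_block"):
+            ...     time.sleep(0.1)
+            >>> report = timeit.print(prefix="run")  # logs and returns the summary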
+        """
         keys = list(timeit._REG)
         keys.sort()
+        string = []
         for name in keys:
             strings = []
             if prefix:
@@ -96,7 +112,9 @@ def print(prefix=None):  # noqa: T202
             strings.append(
                 f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
             )
-            logger.info(" -- ".join(strings))
+            string.append(" -- ".join(strings))
+            logger.info(string[-1])
+        return "\n".join(string)
 
     @classmethod
     def todict(cls, percall=True):
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index f538f8e95c5..cf210985613 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -12,7 +12,7 @@
 
 from tensordict.base import NO_DEFAULT
 
-from tensordict.nn import TensorDictModuleBase as ModuleBase
+from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase
 
 from tensordict.utils import expand_as_right, prod, set_lazy_legacy
 
 from torch import nn, Tensor
@@ -467,6 +467,8 @@ def __init__(
             raise ValueError("The input lstm must have batch_first=True.")
         if bidirectional:
             raise ValueError("The input lstm cannot be bidirectional.")
+        if not hidden_size:
+            raise ValueError("hidden_size must be passed.")
         if python_based:
             lstm = LSTM(
                 input_size=input_size,
@@ -524,6 +526,58 @@ def __init__(
         self.out_keys = out_keys
         self._recurrent_mode = False
 
+    def make_python_based(self) -> LSTMModule:
+        """Transforms the LSTM layer into its Python-based version.
+
+        Returns:
+            self
+
+        """
+        if isinstance(self.lstm, LSTM):
+            return self
+        lstm = LSTM(
+            input_size=self.lstm.input_size,
+            hidden_size=self.lstm.hidden_size,
+            num_layers=self.lstm.num_layers,
+            bias=self.lstm.bias,
+            dropout=self.lstm.dropout,
+            proj_size=self.lstm.proj_size,
+            device="meta",
+            batch_first=self.lstm.batch_first,
+            bidirectional=self.lstm.bidirectional,
+        )
+        from tensordict import from_module
+
+        from_module(self.lstm).to_module(lstm)
+        self.lstm = lstm
+        return self
+
+    def make_cudnn_based(self) -> LSTMModule:
+        """Transforms the LSTM layer into its CuDNN-based version.
+
+        Returns:
+            self
+
+        """
+        if isinstance(self.lstm, nn.LSTM):
+            return self
+        lstm = nn.LSTM(
+            input_size=self.lstm.input_size,
+            hidden_size=self.lstm.hidden_size,
+            num_layers=self.lstm.num_layers,
+            bias=self.lstm.bias,
+            dropout=self.lstm.dropout,
+            proj_size=self.lstm.proj_size,
+            device="meta",
+            batch_first=self.lstm.batch_first,
+            bidirectional=self.lstm.bidirectional,
+        )
+        from tensordict import from_module
+
+        from_module(self.lstm).to_module(lstm)
+        self.lstm = lstm
+        return self
+
     def make_tensordict_primer(self):
         """Makes a tensordict primer for the environment.
@@ -644,6 +698,7 @@ def set_recurrent_mode(self, mode: bool = True):
         out._recurrent_mode = mode
         return out
 
+    @dispatch
     def forward(self, tensordict: TensorDictBase):
         # we want to get an error if the value input is missing, but not the hidden states
         defaults = [NO_DEFAULT, None, None]
@@ -1273,6 +1328,56 @@ def __init__(
         self.out_keys = out_keys
         self._recurrent_mode = False
 
+    def make_python_based(self) -> GRUModule:
+        """Transforms the GRU layer into its Python-based version.
+
+        Returns:
+            self
+
+        """
+        if isinstance(self.gru, GRU):
+            return self
+        gru = GRU(
+            input_size=self.gru.input_size,
+            hidden_size=self.gru.hidden_size,
+            num_layers=self.gru.num_layers,
+            bias=self.gru.bias,
+            dropout=self.gru.dropout,
+            device="meta",
+            batch_first=self.gru.batch_first,
+            bidirectional=self.gru.bidirectional,
+        )
+        from tensordict import from_module
+
+        from_module(self.gru).to_module(gru)
+        self.gru = gru
+        return self
+
+    def make_cudnn_based(self) -> GRUModule:
+        """Transforms the GRU layer into its CuDNN-based version.
+
+        Returns:
+            self
+
+        """
+        if isinstance(self.gru, nn.GRU):
+            return self
+        gru = nn.GRU(
+            input_size=self.gru.input_size,
+            hidden_size=self.gru.hidden_size,
+            num_layers=self.gru.num_layers,
+            bias=self.gru.bias,
+            dropout=self.gru.dropout,
+            device="meta",
+            batch_first=self.gru.batch_first,
+            bidirectional=self.gru.bidirectional,
+        )
+        from tensordict import from_module
+
+        from_module(self.gru).to_module(gru)
+        self.gru = gru
+        return self
+
     def make_tensordict_primer(self):
         """Makes a tensordict primer for the environment.
@@ -1389,6 +1494,7 @@ def set_recurrent_mode(self, mode: bool = True):
         out._recurrent_mode = mode
         return out
 
+    @dispatch
     @set_lazy_legacy(False)
     def forward(self, tensordict: TensorDictBase):
         # we want to get an error if the value input is missing, but not the hidden states
diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py
index 19270658e2e..af8627264bb 100644
--- a/tutorials/sphinx-tutorials/export.py
+++ b/tutorials/sphinx-tutorials/export.py
@@ -15,48 +15,44 @@
 !pip install torchrl
 !pip install "gymnasium[atari,accept-rom-license]"<1.0.0
 
+Introduction
+------------
+
+Learning a policy has little value if that policy cannot be deployed in real-world settings.
+As shown in other tutorials, TorchRL has a strong focus on modularity and composability: thanks to ``tensordict``,
+the components of the library can be written in the most generic way possible by abstracting their signature into a
+mere set of operations on an input ``TensorDict``.
+This may give the impression that the library is bound to be used only for training, as typical low-level execution
+hardware (edge devices, robots, Arduino, Raspberry Pi) does not execute Python code, let alone with PyTorch,
+tensordict or torchrl installed.
+
+Fortunately, PyTorch provides a full ecosystem of solutions to export code and trained models to devices and
+hardware, and TorchRL is fully equipped to interact with it.
+It is possible to choose from a varied set of backends, including ONNX and AOTInductor, both exemplified in this
+tutorial.
+This tutorial gives a quick overview of how a trained model can be isolated and shipped as a standalone executable
+that can be deployed on hardware.
+
+Key learnings:
+
+- Exporting any TorchRL module after training;
+- Using various backends;
+- Testing your exported model.
+
+Fast recap: a simple TorchRL training loop
+------------------------------------------
+
+In this section, we reproduce the training loop from the last Getting Started tutorial, slightly adapted to be used
+with Atari games as they are rendered by the gymnasium library.
+We will stick to the DQN example and show later how a policy that outputs a distribution over values can be used
+instead.
+ """ -import tempfile import time from pathlib import Path import numpy as np import tensordict.utils -################################ -# Introduction -# ------------ -# -# Learning a policy has little value if that policy cannot be deployed in real-world settings. -# As shown in other tutorials, TorchRL has a strong focus on modularity and composability: thanks to ``tensordict``, -# the components of the library can be written in the most generic way there is by abstracting their signature to a -# mere set of operations on an input ``TensorDict``. -# This may give the impression that the library is bound to be used only for training, as typical low-level execution -# hardwares (edge devices, robots, arduino, Raspberry Pi) do not execute python code, let alone with pytorch, tensordict -# or torchrl installed. -# -# Fortunately, PyTorch provides a full ecosystem of solutions to export code and trained models to devices and -# hardwares, and TorchRL is fully equipped to interact with it. -# It is possible to choose from a varied set of backends, including ONNX, TODO X and Y -# This tutorial gives a quick overview of how a trained model can be isolated and shipped as a standalone executable -# to be exported on hardware. -# -# Key learnings: -# -# - Export any TorchRL module after training; -# - Using various backends; -# - Testing your exported model. -# -# Fast recap: a simple TorchRL training loop -# ------------------------------------------ -# -# In this section, we reproduce the training loop from the last Getting Started tutorial, slightly adapted to be used -# with Atari games as they are rendered by the gymnasium library. -# We will stick to the DQN example, and show how a policy that outputs a distribution over values can be used instead -# later. -# - - import torch from tensordict.nn import ( @@ -175,11 +171,15 @@ # We create a fake input, and pass it to :func:`~torch.export.export` with the policy. This will give a "raw" python # function that will read our input tensor and output an action without any reference to TorchRL or tensordict modules. # +# A good practice is to call :meth:`~tensordict.nn.TensorDictSequential.select_out_keys` to let the model know that +# we only want a certain set of outputs (in case the policy returns more than one tensor). +# fake_td = env.base_env.fake_tensordict() pixels = fake_td["pixels"] with set_exploration_type("DETERMINISTIC"): exported_policy = torch.export.export( + # Select only the "action" output key policy_transform.select_out_keys("action"), args=(), kwargs={"pixels": pixels}, @@ -213,7 +213,8 @@ # characters such as spaces or commas. # # ``torch.export`` has many other features that we will explore further below. Before this, let us just do a small -# digression on exploration and stochastic policies in the context of test-time inference. +# digression on exploration and stochastic policies in the context of test-time inference, as well as recurrent +# policies. # # Working with stochastic policies # -------------------------------- @@ -243,8 +244,89 @@ print("Stochastic policy") exported_stochastic_policy.graph_module.print_readable() +##################################### +# Working with recurrent policies +# ------------------------------- +# +# Another typical use case is a recurrent policy that will output an action as well as a one or more recurrent state. 
+# LSTM and GRU are CuDNN-based modules, which means that they will behave differently than regular
+# :class:`~torch.nn.Module` instances (export utils may not trace them well). Fortunately, TorchRL provides a Python
+# implementation of these modules that can be swapped with the CuDNN version when desired.
+#
+# To show this, let us write a prototypical policy that relies on an RNN:
+#
+from tensordict.nn import TensorDictModule
+from torchrl.envs import BatchSizeTransform
+from torchrl.modules import LSTMModule, MLP
+
+lstm = LSTMModule(
+    input_size=32,
+    num_layers=2,
+    hidden_size=256,
+    in_keys=["observation", "hidden0", "hidden1"],
+    out_keys=["intermediate", "hidden0", "hidden1"],
+)
+#####################################
+# We set the recurrent mode to ``False`` to allow the module to read inputs one-by-one and not in batch.
+#
+lstm = lstm.set_recurrent_mode(False)
+
+#####################################
+# If the LSTM module is not Python-based but CuDNN-based (:class:`~torch.nn.LSTM`), the
+# :meth:`~torchrl.modules.LSTMModule.make_python_based` method can be used to switch to the Python version.
+#
+lstm = lstm.make_python_based()
 
-from tempfile import TemporaryDirectory
 
+#####################################
+# Let's now create the policy. We combine two layers that modify the shape of the input (unsqueeze/squeeze operations)
+# with the LSTM and an MLP.
+#
+
+recurrent_policy = TensorDictSequential(
+    # Unsqueeze the first dim of all tensors to make LSTMCell happy
+    BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
+    lstm,
+    TensorDictModule(
+        MLP(in_features=256, out_features=5, num_cells=[64, 64]),
+        in_keys=["intermediate"],
+        out_keys=["action"],
+    ),
+    # Squeeze the first dim of all tensors to get the original shape back
+    BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
+)
+
+#####################################
+# As before, we select the relevant keys:
+#
+
+recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
+print("recurrent policy input keys:", recurrent_policy.in_keys)
+print("recurrent policy output keys:", recurrent_policy.out_keys)
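+
+#####################################
+# As a quick sanity check (an optional step, not required for export), we can run the policy once eagerly,
+# with inputs shaped like the fake tensors used for the export call below, and verify that an action comes out:
+#
+
+action, hidden0, hidden1 = recurrent_policy(
+    observation=torch.randn(32),
+    hidden0=torch.randn(2, 256),
+    hidden1=torch.randn(2, 256),
+    is_init=torch.zeros((), dtype=torch.bool),
+)
+print("eager action shape:", action.shape)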
+
+#####################################
+# We are now ready to export. To do this, we build fake inputs and pass them to :func:`~torch.export.export`:
+#
+fake_obs = torch.randn(32)
+fake_hidden0 = torch.randn(2, 256)
+fake_hidden1 = torch.randn(2, 256)
+
+# Tensor indicating whether the state is the first of a sequence
+fake_is_init = torch.zeros((), dtype=torch.bool)
+
+exported_recurrent_policy = torch.export.export(
+    recurrent_policy,
+    args=(),
+    kwargs={
+        "observation": fake_obs,
+        "hidden0": fake_hidden0,
+        "hidden1": fake_hidden1,
+        "is_init": fake_is_init,
+    },
+    strict=False,
+)
+print("Recurrent policy graph:")
+exported_recurrent_policy.graph_module.print_readable()
 
 #####################################
 # AOTInductor: Export your policy to pytorch-free C++ binaries
 # ------------------------------------------------------------
 #
 # `AOTI documentation `_:
 #
 
+from tempfile import TemporaryDirectory
+
 from torch._inductor import aoti_compile_and_package, aoti_load_package
 
 with TemporaryDirectory() as tmpdir:
     path = str(Path(tmpdir) / "policy_aoti")
     pkg_path = aoti_compile_and_package(
         exported_policy,
         # Specify the generated shared library path
         package_path=path,
     )
+    print("pkg_path", pkg_path)
 
-    compiled_module = aoti_load_package(str(Path(tmpdir) / "model.pt2"))
-    # Print the structor of our temporary directory, including file size
-    tensordict.utils.print_directory_tree(tmpdir)
+    compiled_module = aoti_load_package(pkg_path)
 
     print(compiled_module(pixels=pixels))
 
@@ -321,6 +404,12 @@
 screen_obs = ale.getScreenRGB()
 print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)
 
+from matplotlib import pyplot as plt
+
+plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
+plt.imshow(screen_obs)
+plt.title("Screen rendering of Pong game.")
+
 #####################################
 # Exporting to ONNX is quite similar to the Export/AOTI above:
 #
 
 #####################################
 # We can now save the program on disk and load it:
 
-with tempfile.TemporaryDirectory() as tmpdir:
+with TemporaryDirectory() as tmpdir:
     onnx_file_path = str(Path(tmpdir) / "policy.onnx")
     onnx_policy_export.save(onnx_file_path)
 
@@ -371,7 +460,7 @@ def onnx_policy(screen_obs: np.ndarray) -> int:
 with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
     env.rollout(num_steps, policy_explore)
 
-timeit.print()
+print(timeit.print())
 
 #####################################
 # Note that ONNX also offers the possibility of optimizing models directly, but this is beyond the scope of this tutorial.