diff --git a/docs/requirements.txt b/docs/requirements.txt
index 702a2884421..e212cd942f4 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -25,3 +25,6 @@ memory_profiler
 pyrender
 pytest
 vmas
+onnxscript
+onnxruntime
+onnx
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5f3882faed7..0e034ef2351 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -104,6 +104,7 @@ Intermediate
    tutorials/pretrained_models
    tutorials/dqn_with_rnn
    tutorials/rb_tutorial
+   tutorials/export
 
 Advanced
 --------
diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py
index 3faaa396299..10884f8d415 100644
--- a/torchrl/modules/models/models.py
+++ b/torchrl/modules/models/models.py
@@ -201,6 +201,9 @@ def __init__(
         if not isinstance(out_features, Number):
             _out_features_num = prod(out_features)
         self.out_features = out_features
+        self._reshape_out = not isinstance(
+            self.out_features, (int, torch.SymInt, Number)
+        )
         self._out_features_num = _out_features_num
         self.activation_class = activation_class
         self.norm_class = norm_class
@@ -302,7 +305,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
             inputs = (torch.cat([*inputs], -1),)
 
         out = super().forward(*inputs)
-        if not isinstance(self.out_features, Number):
+        if self._reshape_out:
             out = out.view(*out.shape[:-1], *self.out_features)
         return out
 
@@ -549,6 +552,27 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
             out = out.unflatten(0, batch)
         return out
 
+    @classmethod
+    def default_atari_dqn(cls, num_actions: int):
+        """Returns the default DQN as presented in the seminal DQN paper.
+
+        Args:
+            num_actions (int): the number of actions available in the Atari game.
+
+        """
+        cnn = ConvNet(
+            activation_class=torch.nn.ReLU,
+            num_cells=[32, 64, 64],
+            kernel_sizes=[8, 4, 3],
+            strides=[4, 2, 1],
+        )
+        mlp = MLP(
+            activation_class=torch.nn.ReLU,
+            out_features=num_actions,
+            num_cells=[512],
+        )
+        return nn.Sequential(cnn, mlp)
+
 
 Conv2dNet = ConvNet
 
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index 7337d1c94dd..cfe81abee44 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -152,6 +152,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             out = action_tensordict.get(action_key)
             eps = self.eps.item()
             cond = torch.rand(action_tensordict.shape, device=out.device) < eps
+            # cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
             cond = expand_as_right(cond, out)
             spec = self.spec
             if spec is not None:
diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py
new file mode 100644
index 00000000000..19270658e2e
--- /dev/null
+++ b/tutorials/sphinx-tutorials/export.py
@@ -0,0 +1,402 @@
+"""
+Exporting TorchRL modules
+=========================
+
+**Author**: `Vincent Moens `_
+
+.. _export_tuto:
+
+.. note:: To run this tutorial in a notebook, add an installation cell
+  at the beginning containing:
+
+    .. code-block::
+
+        !pip install tensordict
+        !pip install torchrl
+        !pip install "gymnasium[atari,accept-rom-license]<1.0.0"
+
+"""
+import tempfile
+import time
+from pathlib import Path
+
+import numpy as np
+import tensordict.utils
+
+################################
+# Introduction
+# ------------
+#
+# Learning a policy has little value if that policy cannot be deployed in real-world settings.
+# As shown in other tutorials, TorchRL has a strong focus on modularity and composability: thanks to ``tensordict``,
+# the components of the library can be written in the most generic way possible by abstracting their signature to a
+# mere set of operations on an input ``TensorDict``.
+# This may give the impression that the library is bound to be used only for training, as typical low-level execution
+# hardware (edge devices, robots, Arduino, Raspberry Pi) does not execute Python code, let alone with pytorch,
+# tensordict or torchrl installed.
+#
+# Fortunately, PyTorch provides a full ecosystem of solutions to export code and trained models to various devices
+# and hardware, and TorchRL is fully equipped to interact with it.
+# It is possible to choose from a varied set of backends, including ``torch.export``, AOTInductor and ONNX.
+# This tutorial gives a quick overview of how a trained model can be isolated and shipped as a standalone executable
+# that can be deployed on hardware.
+#
+# Key learnings:
+#
+# - Exporting any TorchRL module after training;
+# - Using various backends;
+# - Testing your exported model.
+#
+# Fast recap: a simple TorchRL training loop
+# ------------------------------------------
+#
+# In this section, we reproduce the training loop from the last Getting Started tutorial, slightly adapted to be used
+# with Atari games as they are rendered by the gymnasium library.
+# We will stick to the DQN example, and show later on how a policy that outputs a distribution over values can be
+# used instead.
+#
+
+
+import torch
+
+from tensordict.nn import (
+    TensorDictModule as Mod,
+    TensorDictSequential,
+    TensorDictSequential as Seq,
+)
+
+from torch.optim import Adam
+
+from torchrl._utils import timeit
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyTensorStorage, ReplayBuffer
+
+from torchrl.envs import (
+    Compose,
+    GrayScale,
+    GymEnv,
+    Resize,
+    set_exploration_type,
+    StepCounter,
+    ToTensorImage,
+    TransformedEnv,
+)
+
+from torchrl.modules import ConvNet, EGreedyModule, QValueModule
+
+from torchrl.objectives import DQNLoss, SoftUpdate
+
+torch.manual_seed(0)
+
+env = TransformedEnv(
+    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
+    Compose(
+        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
+    ),
+)
+env.set_seed(0)
+
+value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
+value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
+policy = Seq(value_net, QValueModule(spec=env.action_spec))
+exploration_module = EGreedyModule(
+    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
+)
+policy_explore = Seq(policy, exploration_module)
+
+init_rand_steps = 5000
+frames_per_batch = 100
+optim_steps = 10
+collector = SyncDataCollector(
+    env,
+    policy_explore,
+    frames_per_batch=frames_per_batch,
+    total_frames=-1,
+    init_random_frames=init_rand_steps,
+)
+rb = ReplayBuffer(storage=LazyTensorStorage(100_000))
+
+loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
+optim = Adam(loss.parameters())
+updater = SoftUpdate(loss, eps=0.99)
+
+total_count = 0
+total_episodes = 0
+t0 = time.time()
+for data in collector:
+    # Write data in replay buffer
+    rb.extend(data)
+    max_length = rb[:]["next", "step_count"].max()
+    if len(rb) > init_rand_steps:
+        # Optim loop (we do several optim steps
+        # per batch collected for efficiency)
+        for _ in range(optim_steps):
+            sample = rb.sample(128)
+            loss_vals = loss(sample)
+            loss_vals["loss"].backward()
+            optim.step()
+            optim.zero_grad()
+            # Update exploration factor
+            exploration_module.step(data.numel())
+            # Update target params
+            updater.step()
+    total_count += data.numel()
+    total_episodes += data["next", "done"].sum()
+    if max_length > 200:
+        break
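+
+#####################################
+# Before turning to export, we may want to keep a copy of the trained weights around.
+# This is an optional, minimal checkpointing sketch (the file name below is arbitrary):
+# the exported artifacts produced later in this tutorial are self-contained, so a checkpoint
+# is only useful if you plan to rebuild the policy in Python at a later time.
+
+ckpt_path = Path(tempfile.gettempdir()) / "dqn_pong_policy.pt"
+torch.save(policy.state_dict(), ckpt_path)
+# The weights can be restored into an identically built policy:
+policy.load_state_dict(torch.load(ckpt_path, weights_only=True))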
+
+#####################################
+# Exporting a TensorDictModule-based policy
+# -----------------------------------------
+#
+# ``TensorDict`` allowed us to build a policy with great flexibility: from a regular :class:`~torch.nn.Module` that
+# outputs action values from an observation, we added a :class:`~torchrl.modules.QValueModule` that
+# reads these values and computes an action using some heuristic (e.g., an argmax call).
+#
+# However, there's a small technical catch in our case: the environment (the actual Atari game) doesn't return
+# grayscale, 84x84 images but raw screen-size color ones. The transforms we appended to the environment make sure that
+# the images can be read by the model. We can see that, from the training perspective, the boundary between environment
+# and model is blurry, but at execution time things are much clearer: the model should take care of transforming
+# the input data (images) to the format that can be processed by our CNN.
+#
+# Here again, the magic of tensordict will unblock us: it happens that most of TorchRL's local (non-recursive)
+# transforms can be used both as environment transforms and as preprocessing blocks within a :class:`~torch.nn.Module`
+# instance. Let's see how we can prepend them to our policy:
+
+policy_transform = TensorDictSequential(
+    env.transform[
+        :-1
+    ],  # the last transform is a step counter which we don't need for preproc
+    policy_explore.requires_grad_(
+        False
+    ),  # Using the explorative version of the policy for didactic purposes, see below.
+)
+#####################################
+# We create a fake input, and pass it to :func:`~torch.export.export` with the policy. This will give a "raw" Python
+# function that will read our input tensor and output an action without any reference to TorchRL or tensordict modules.
+#
+
+fake_td = env.base_env.fake_tensordict()
+pixels = fake_td["pixels"]
+with set_exploration_type("DETERMINISTIC"):
+    exported_policy = torch.export.export(
+        policy_transform.select_out_keys("action"),
+        args=(),
+        kwargs={"pixels": pixels},
+        strict=False,
+    )
+
+#####################################
+# Displaying the exported policy can be quite insightful: we can see that the first operations are a permute, a div,
+# an unsqueeze and a resize, followed by the convolutional and MLP layers.
+#
+print("Deterministic policy")
+exported_policy.graph_module.print_readable()
+
+#####################################
+# As a final check, we can execute the policy with a dummy input. The output (for a single image) should be an integer
+# between 0 and 5 representing the action to be executed in the game (Pong exposes six discrete actions).
+
+output = exported_policy.module()(pixels=pixels)
+print("Exported module output", output)
+
+#####################################
+# Further details on exporting :class:`~tensordict.nn.TensorDictModule` instances can be found in the tensordict
+# `documentation `_.
+#
+# .. note::
+#   Exporting modules that take and output nested keys is perfectly fine.
+#   The corresponding kwargs will be the `"_".join(key)` version of the key, i.e., the `("group0", "agent0", "obs")`
+#   key will correspond to the `"group0_agent0_obs"` keyword argument. Colliding keys (e.g., `("group0_agent0", "obs")`
+#   and `("group0", "agent0_obs")`) may lead to undefined behaviours and should be avoided at all costs.
+#   Obviously, key names should also always produce valid keyword arguments, i.e., they should not contain special
+#   characters such as spaces or commas.
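+#
+# To make this note concrete, here is a minimal, self-contained sketch of exporting a module that reads a
+# nested key. The module, key names and shapes below are made up for illustration only: the nested
+# ``("group0", "agent0", "obs")`` entry is exposed as the flattened ``group0_agent0_obs`` keyword argument.
+
+nested_module = Mod(
+    torch.nn.Linear(4, 2), in_keys=[("group0", "agent0", "obs")], out_keys=["action"]
+)
+exported_nested = torch.export.export(
+    nested_module,
+    args=(),
+    kwargs={"group0_agent0_obs": torch.randn(3, 4)},
+    strict=False,
+)
+print(exported_nested.module()(group0_agent0_obs=torch.randn(3, 4)))
+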
+#####################################
+# ``torch.export`` has many other features that we will explore further below. Before this, let us just do a small
+# digression on exploration and stochastic policies in the context of test-time inference.
+#
+# Working with stochastic policies
+# --------------------------------
+#
+# As you probably noted, above we used the :class:`~torchrl.envs.set_exploration_type` context manager to control
+# the behaviour of the policy. If the policy is stochastic (e.g., the policy outputs a distribution over the action
+# space, as is the case in PPO or other similar on-policy algorithms) or explorative (with an exploration module
+# appended, like E-Greedy, additive Gaussian or Ornstein-Uhlenbeck noise), we may or may not want to use that
+# exploration strategy in the exported version.
+# Fortunately, the export utilities understand that context manager: as long as the export occurs within it,
+# the behaviour of the exported policy will match the requested exploration type. To demonstrate this, let us try with
+# another exploration type:

+with set_exploration_type("RANDOM"):
+    exported_stochastic_policy = torch.export.export(
+        policy_transform.select_out_keys("action"),
+        args=(),
+        kwargs={"pixels": pixels},
+        strict=False,
+    )
+
+#####################################
+# Our exported policy should now have a random module at the end of the call stack, unlike the previous version.
+# Indeed, the last three operations are: draw a random action, build a random mask, and select either the network
+# output or the random action depending on the value in the mask.
+#
+print("Stochastic policy")
+exported_stochastic_policy.graph_module.print_readable()
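+
+#####################################
+# As a quick, illustrative sanity check, we can call the two exported policies several times on the same
+# observation: the deterministic export always returns the same action, while the stochastic export may
+# occasionally pick a random one (whether it does here depends on the current value of the exploration
+# factor ``eps``).
+
+print("Deterministic export, repeated calls:")
+for _ in range(3):
+    print(exported_policy.module()(pixels=pixels))
+print("Stochastic export, repeated calls:")
+for _ in range(3):
+    print(exported_stochastic_policy.module()(pixels=pixels))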
+
+
+from tempfile import TemporaryDirectory
+
+#####################################
+# AOTInductor: Export your policy to PyTorch-free C++ binaries
+# ------------------------------------------------------------
+#
+# AOTInductor is a PyTorch feature that allows you to export your model (policy or other) to PyTorch-free C++ binaries.
+# This is particularly useful when you need to deploy your model on devices or platforms where PyTorch is not available.
+#
+# Here's an example of how you can use AOTInductor to export your policy, inspired by the
+# `AOTI documentation `_:
+#
+
+from torch._inductor import aoti_compile_and_package, aoti_load_package
+
+with TemporaryDirectory() as tmpdir:
+    path = str(Path(tmpdir) / "model.pt2")
+    with torch.no_grad():
+        pkg_path = aoti_compile_and_package(
+            exported_policy,
+            args=(),
+            kwargs={"pixels": pixels},
+            # Specify the generated shared library path
+            package_path=path,
+        )
+
+    compiled_module = aoti_load_package(str(Path(tmpdir) / "model.pt2"))
+    # Print the structure of our temporary directory, including file sizes
+    tensordict.utils.print_directory_tree(tmpdir)
+
+print(compiled_module(pixels=pixels))
+
+#####################################
+# Exporting TorchRL models with ONNX
+# ----------------------------------
+#
+# .. note:: To execute this part of the script, make sure the ONNX dependencies are installed:
+#
+#   .. code-block::
+#
+#     !pip install onnx onnxscript
+#     !pip install onnxruntime
+#
+# You can also find more information about using ONNX in the PyTorch ecosystem
+# `here `_. The following example is based on this
+# documentation.
+#
+# In this section, we are going to showcase how we can export our model in such a way that it can be
+# executed in a PyTorch-free setting.
+#
+# There are plenty of resources on the web explaining how ONNX can be used to deploy PyTorch models on various
+# hardware and devices, including `Raspberry Pi `_,
+# `NVIDIA TensorRT `_,
+# `iOS `_ and
+# `Android `_.
+#
+# The Atari game we trained on can be run in isolation, without TorchRL or gymnasium, with the
+# `ALE library `_ and therefore provides us with
+# a good example of what we can achieve with ONNX.
+#
+# Let us see what this API looks like:
+
+from ale_py import ALEInterface, roms
+
+# Create the interface
+ale = ALEInterface()
+
+# Load the Pong environment
+ale.loadROM(roms.Pong)
+ale.reset_game()
+
+# Make a step in the simulator
+action = 0
+reward = ale.act(action)
+screen_obs = ale.getScreenRGB()
+print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)
+
+#####################################
+# Exporting to ONNX is quite similar to the ``torch.export``/AOTI examples above:
+
+import onnxruntime
+
+with set_exploration_type("DETERMINISTIC"):
+    # We use torch.onnx.dynamo_export to capture the computation graph from our policy_explore model
+    pixels = torch.as_tensor(screen_obs)
+    onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)
+
+#####################################
+# We can now save the program on disk and load it:
+with tempfile.TemporaryDirectory() as tmpdir:
+    onnx_file_path = str(Path(tmpdir) / "policy.onnx")
+    onnx_policy_export.save(onnx_file_path)
+
+    ort_session = onnxruntime.InferenceSession(
+        onnx_file_path, providers=["CPUExecutionProvider"]
+    )
+
+onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+#####################################
+# Running a rollout with ONNX
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We now have an ONNX model that runs our policy. Let's compare it to the original TorchRL instance: because it is
+# more lightweight, the ONNX version should be faster than the TorchRL one.
+
+
+def onnx_policy(screen_obs: np.ndarray) -> int:
+    onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
+    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+    action = int(onnxruntime_outputs[0])
+    return action
+
+
+with timeit("ONNX rollout"):
+    num_steps = 1000
+    ale.reset_game()
+    for _ in range(num_steps):
+        screen_obs = ale.getScreenRGB()
+        action = onnx_policy(screen_obs)
+        reward = ale.act(action)
+
+with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
+    env.rollout(num_steps, policy_explore)
+
+timeit.print()
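+
+#####################################
+# As an extra, optional check, we can verify that the ONNX policy and the ``torch.export`` version select the
+# same action for the same frame. Small numerical differences between runtimes could in principle flip an argmax,
+# so treat this as a sanity check rather than a strict guarantee.
+
+screen_obs = ale.getScreenRGB()
+torch_action = exported_policy.module()(pixels=torch.as_tensor(screen_obs))
+print("ONNX action:", onnx_policy(screen_obs), "| torch.export action:", torch_action)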
+
+#####################################
+# Note that ONNX also offers the possibility of optimizing models directly, but this is beyond the scope of this
+# tutorial.
+#
+# Conclusion
+# ----------
+#
+# In this tutorial, we learned how to export TorchRL modules using various backends such as PyTorch's built-in export
+# functionality, ``AOTInductor``, and ``ONNX``.
+# We demonstrated how to export a policy trained on an Atari game and run it in a PyTorch-free setting using the ``ALE``
+# library. We also compared the performance of the original TorchRL instance with the exported ONNX model.
+#
+# Key takeaways:
+#
+# - Exporting TorchRL modules allows for deployment on devices without PyTorch installed.
+# - AOTInductor and ONNX provide alternative backends for exporting models.
+# - Optimizing ONNX models can improve performance.
+#
+# Further reading and learning steps:
+#
+# - Check out the official documentation for PyTorch's `export functionality `_,
+#   `AOTInductor `_, and
+#   `ONNX `_ for more
+#   information.
+# - Experiment with deploying exported models on different devices.
+# - Explore optimization techniques for ONNX models to improve performance.
+#