From 588b4e77964377797bc3fc469c12054aeb922447 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 1 Feb 2024 16:23:43 +0000 Subject: [PATCH 1/9] init --- tutorials/sphinx-tutorials/coding_ddpg.py | 196 +++++++++--------- tutorials/sphinx-tutorials/coding_dqn.py | 3 +- tutorials/sphinx-tutorials/coding_ppo.py | 185 ++++++++--------- tutorials/sphinx-tutorials/dqn_with_rnn.py | 3 +- tutorials/sphinx-tutorials/multi_task.py | 3 +- tutorials/sphinx-tutorials/pendulum.py | 3 +- .../sphinx-tutorials/pretrained_models.py | 3 +- tutorials/sphinx-tutorials/rb_tutorial.py | 3 +- tutorials/sphinx-tutorials/torchrl_demo.py | 3 +- tutorials/sphinx-tutorials/torchrl_envs.py | 3 +- 10 files changed, 200 insertions(+), 205 deletions(-) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 85590c545fa..dddd8963c2b 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -7,6 +7,9 @@ """ ############################################################################## +# Overview +# -------- +# # TorchRL separates the training of RL algorithms in various pieces that will be # assembled in your training script: the environment, the data collection and # storage, the model and finally the loss function. @@ -14,29 +17,33 @@ # TorchRL losses (or "objectives") are stateful objects that contain the # trainable parameters (policy and value models). # This tutorial will guide you through the steps to code a loss from the ground up -# using torchrl. +# using TorchRL. # # To this aim, we will be focusing on DDPG, which is a relatively straightforward # algorithm to code. -# DDPG (`Deep Deterministic Policy Gradient `_) +# `Deep Deterministic Policy Gradient `_ (DDPG) # is a simple continuous control algorithm. It consists in learning a # parametric value function for an action-observation pair, and -# then learning a policy that outputs actions that maximise this value +# then learning a policy that outputs actions that maximize this value # function given a certain observation. # -# Key learnings: +# What you will learn: # # - how to write a loss module and customize its value estimator; -# - how to build an environment in torchrl, including transforms -# (e.g. data normalization) and parallel execution; +# - how to build an environment in TorchRL, including transforms +# (for example, data normalization) and parallel execution; # - how to design a policy and value network; # - how to collect data from your environment efficiently and store them # in a replay buffer; # - how to store trajectories (and not transitions) in your replay buffer); -# - and finally how to evaluate your model. +# - how to evaluate your model. +# +# Prerequisites +# ~~~~~~~~~~~~~ # -# This tutorial assumes that you have completed the PPO tutorial which gives -# an overview of the torchrl components and dependencies, such as +# This tutorial assumes that you have completed the +# `PPO tutorial `_ which gives +# an overview of the TorchRL components and dependencies, such as # :class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModules`, # although it should be # sufficiently transparent to be understood without a deep understanding of @@ -44,17 +51,20 @@ # # .. note:: # We do not aim at giving a SOTA implementation of the algorithm, but rather -# to provide a high-level illustration of torchrl's loss implementations +# to provide a high-level illustration of TorchRL's loss implementations # and the library features that are to be used in the context of # this algorithm. # # Imports and setup # ----------------- # +# .. code-block:: bash +# +# %%bash +# pip3 install torchrl mujoco glfw # sphinx_gallery_start_ignore import warnings -from typing import Tuple warnings.filterwarnings("ignore") from torch import multiprocessing @@ -62,25 +72,26 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore -import torch.cuda + +import torch import tqdm ############################################################################### -# We will execute the policy on cuda if available -device = ( - torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") -) +# We will execute the policy on CUDA if available +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA ############################################################################### -# torchrl :class:`~torchrl.objectives.LossModule` +# TorchRL :class:`~torchrl.objectives.LossModule` # ----------------------------------------------- # # TorchRL provides a series of losses to use in your training scripts. @@ -89,11 +100,11 @@ # # The main characteristics of TorchRL losses are: # -# - they are stateful objects: they contain a copy of the trainable parameters +# - They are stateful objects: they contain a copy of the trainable parameters # such that ``loss_module.parameters()`` gives whatever is needed to train the # algorithm. -# - They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward` -# method will receive a tensordict as input that contains all the necessary +# - They follow the ``TensorDict`` convention: the :meth:`torch.nn.Module.forward` +# method will receive a TensorDict as input that contains all the necessary # information to return a loss value. # # >>> data = replay_buffer.sample() @@ -101,8 +112,9 @@ # # - They output a :class:`tensordict.TensorDict` instance with the loss values # written under a ``"loss_"`` where ``smth`` is a string describing the -# loss. Additional keys in the tensordict may be useful metrics to log during +# loss. Additional keys in the ``TensorDict`` may be useful metrics to log during # training time. +# # .. note:: # The reason we return independent losses is to let the user use a different # optimizer for different sets of parameters for instance. Summing the losses @@ -129,14 +141,14 @@ # # Let us start with the :meth:`~torchrl.objectives.LossModule.__init__` # method. DDPG aims at solving a control task with a simple strategy: -# training a policy to output actions that maximise the value predicted by +# training a policy to output actions that maximize the value predicted by # a value network. Hence, our loss module needs to receive two networks in its # constructor: an actor and a value networks. We expect both of these to be -# tensordict-compatible objects, such as +# TensorDict-compatible objects, such as # :class:`tensordict.nn.TensorDictModule`. # Our loss function will need to compute a target value and fit the value # network to this, and generate an action and fit the policy such that its -# value estimate is maximised. +# value estimate is maximized. # # The crucial step of the :meth:`LossModule.__init__` method is the call to # :meth:`~torchrl.LossModule.convert_to_functional`. This method will extract @@ -149,7 +161,7 @@ # model with different sets of parameters, called "trainable" and "target" # parameters. # The "trainable" parameters are those that the optimizer needs to fit. The -# "target" parameters are usually a copy of the formers with some time lag +# "target" parameters are usually a copy of the former's with some time lag # (absolute or diluted through a moving average). # These target parameters are used to compute the value associated with the # next observation. One the advantages of using a set of target parameters @@ -163,7 +175,7 @@ # accessible but this will just return a **detached** version of the # actor parameters. # -# Later, we will see how the target parameters should be updated in torchrl. +# Later, we will see how the target parameters should be updated in TorchRL. # from tensordict.nn import TensorDictModule @@ -236,11 +248,11 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): value_key = "state_action_value" if value_type == ValueEstimators.TD1: self._value_estimator = TD1Estimator( - value_network=self.actor_critic, value_key=value_key, **hp + value_network=self.actor_critic, **hp ) elif value_type == ValueEstimators.TD0: self._value_estimator = TD0Estimator( - value_network=self.actor_critic, value_key=value_key, **hp + value_network=self.actor_critic, **hp ) elif value_type == ValueEstimators.GAE: raise NotImplementedError( @@ -248,14 +260,14 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): ) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator( - value_network=self.actor_critic, value_key=value_key, **hp + value_network=self.actor_critic, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") - + self._value_estimator.set_keys(value=value_key) ############################################################################### -# The ``make_value_estimator`` method can but does not need to be called: if +# The ``make_value_estimator`` method can but does not need to be called: ifgg # not, the :class:`~torchrl.objectives.LossModule` will query this method with # its default estimator. # @@ -265,7 +277,7 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): # The central piece of an RL algorithm is the training loss for the actor. # In the case of DDPG, this function is quite simple: we just need to compute # the value associated with an action computed using the policy and optimize -# the actor weights to maximise this value. +# the actor weights to maximize this value. # # When computing this value, we must make sure to take the value parameters out # of the graph, otherwise the actor and value loss will be mixed up. @@ -302,7 +314,7 @@ def _loss_actor( def _loss_value( self, tensordict, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +): td_copy = tensordict.clone() # V(s, a) @@ -325,7 +337,7 @@ def _loss_value( tensordict, target_params=target_params ).squeeze(-1) - # Computes the value loss: L2, L1 or smooth L1 depending on self.loss_funtion + # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function` loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function) td_error = (pred_val - target_value).pow(2) @@ -337,7 +349,7 @@ def _loss_value( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The only missing piece is the forward method, which will glue together the -# value and actor loss, collect the cost values and write them in a tensordict +# value and actor loss, collect the cost values and write them in a ``TensorDict`` # delivered to the user. from tensordict import TensorDict, TensorDictBase @@ -397,7 +409,7 @@ class DDPGLoss(LossModule): # For this example, we will be using the ``"cheetah"`` task. The goal is to make # a half-cheetah run as fast as possible. # -# In TorchRL, one can create such a task by relying on dm_control or gym: +# In TorchRL, one can create such a task by relying on ``dm_control`` or ``gym``: # # .. code-block:: python # @@ -411,7 +423,7 @@ class DDPGLoss(LossModule): # # By default, these environment disable rendering. Training from states is # usually easier than training from images. To keep things simple, we focus -# on learning from states only. To pass the pixels to the tensordicts that +# on learning from states only. To pass the pixels to the ``tensordicts`` that # are collected by :func:`env.step()`, simply pass the ``from_pixels=True`` # argument to the constructor: # @@ -420,7 +432,7 @@ class DDPGLoss(LossModule): # env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True) # # We write a :func:`make_env` helper function that will create an environment -# with either one of the two backends considered above (dm-control or gym). +# with either one of the two backends considered above (``dm-control`` or ``gym``). # from torchrl.envs.libs.dm_control import DMControlEnv @@ -431,7 +443,7 @@ class DDPGLoss(LossModule): def make_env(from_pixels=False): - """Create a base env.""" + """Create a base ``env``.""" global env_library global env_name @@ -502,7 +514,7 @@ def make_env(from_pixels=False): def make_transformed_env( env, ): - """Apply transforms to the env (such as reward scaling and state normalization).""" + """Apply transforms to the ``env`` (such as reward scaling and state normalization).""" env = TransformedEnv(env) @@ -511,16 +523,6 @@ def make_transformed_env( # syntax. env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) - double_to_float_list = [] - double_to_float_inv_list = [] - if env_library is DMControlEnv: - # DMControl requires double-precision - double_to_float_list += [ - "reward", - "action", - ] - double_to_float_inv_list += ["action"] - # We concatenate all states into a single "observation_vector" # even if there is a single tensor, it'll be renamed in "observation_vector". # This facilitates the downstream operations as we know the name of the @@ -536,16 +538,14 @@ def make_transformed_env( # version of the transform env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True)) - double_to_float_list.append(out_key) env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) + DoubleToFloat() ) env.append_transform(StepCounter(max_frames_per_traj)) - # We need a marker for the start of trajectories for our OU exploration: + # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU) + # exploration: env.append_transform(InitTracker()) return env @@ -608,15 +608,16 @@ def make_t_env(): return env -# The backend can be gym or dm_control +# The backend can be ``gym`` or ``dm_control`` backend = "gym" ############################################################################### # .. note:: +# # ``frame_skip`` batches multiple step together with a single action -# If > 1, the other frame counts (e.g. frames_per_batch, total_frames) need to -# be adjusted to have a consistent total number of frames collected across -# experiments. This is important as raising the frame-skip but keeping the +# If > 1, the other frame counts (for example, frames_per_batch, total_frames) +# need to be adjusted to have a consistent total number of frames collected +# across experiments. This is important as raising the frame-skip but keeping the # total number of frames unchanged may seem like cheating: all things compared, # a dataset of 10M elements collected with a frame-skip of 2 and another with # a frame-skip of 1 actually have a ratio of interactions with the environment @@ -630,7 +631,7 @@ def make_t_env(): ############################################################################### # We also define when a trajectory will be truncated. A thousand steps (500 if -# frame-skip = 2) is a good number to use for cheetah: +# frame-skip = 2) is a good number to use for the cheetah task: max_frames_per_traj = 500 @@ -660,7 +661,7 @@ def get_env_stats(): ############################################################################### # Normalization stats # ~~~~~~~~~~~~~~~~~~~ -# Number of random steps used as for stats computation using ObservationNorm +# Number of random steps used as for stats computation using ``ObservationNorm`` init_env_steps = 5000 @@ -764,8 +765,8 @@ def make_ddpg_actor( module=q_net, ).to(device) - # init lazy moduless - qnet(actor(proof_environment.reset())) + # initialize lazy modules + qnet(actor(proof_environment.reset().to(device))) return actor, qnet @@ -779,7 +780,7 @@ def make_ddpg_actor( # ~~~~~~~~~~~ # # The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` -# exploration module, as suggesed in the original paper. +# exploration module, as suggested in the original paper. # Let's define the number of frames before OU noise reaches its minimum value annealing_frames = 1_000_000 @@ -801,24 +802,27 @@ def make_ddpg_actor( # environment and reset it when required. # Data collectors are designed to help developers have a tight control # on the number of frames per batch of data, on the (a)sync nature of this -# collection and on the resources allocated to the data collection (e.g. GPU, -# number of workers etc). +# collection and on the resources allocated to the data collection (for example +# GPU, number of workers, and so on). # # Here we will use -# :class:`~torchrl.collectors.MultiaSyncDataCollector`, a data collector that -# will be executed in an async manner (i.e. data will be collected while -# the policy is being optimized). With the :class:`MultiaSyncDataCollector`, -# multiple workers are running rollouts separately. When a batch is asked, it -# is gathered from the first worker that can provide it. +# :class:`~torchrl.collectors.SyncDataCollector`, a simple, single-process +# data collector. TorchRL offers other collectors, such as +# :class:`~torchrl.collectors.MultiaSyncDataCollector`, which executed the +# rollouts in an asynchronous manner (for example, data will be collected while +# the policy is being optimized, thereby decoupling the training and +# data collection). # # The parameters to specify are: # -# - the list of environment creation functions, +# - an environment factory or an environment, # - the policy, # - the total number of frames before the collector is considered empty, # - the maximum number of frames per trajectory (useful for non-terminating -# environments, like dm_control ones). +# environments, like ``dm_control`` ones). +# # .. note:: +# # The ``max_frames_per_traj`` passed to the collector will have the effect # of registering a new :class:`~torchrl.envs.StepCounter` transform # with the environment used for inference. We can achieve the same result @@ -837,8 +841,8 @@ def make_ddpg_actor( ############################################################################### # The number of frames returned by the collector at each iteration of the outer -# loop is equal to the length of each sub-trajectories times the number of envs -# run in parallel in each collector. +# loop is equal to the length of each sub-trajectories times the number of +# environments run in parallel in each collector. # # In other words, we expect batches from the collector to have a shape # ``[env_per_collector, traj_len]`` where @@ -849,26 +853,18 @@ def make_ddpg_actor( init_random_frames = 5000 num_collectors = 2 -from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.envs import ExplorationType -collector = MultiaSyncDataCollector( - create_env_fn=[ - parallel_env, - ] - * num_collectors, +collector = SyncDataCollector( + parallel_env, policy=actor_model_explore, total_frames=total_frames, - # max_frames_per_traj=max_frames_per_traj, # this is achieved by the env constructor frames_per_batch=frames_per_batch, init_random_frames=init_random_frames, reset_at_each_iter=False, split_trajs=False, - device=device, - # device for execution - storing_device=device, - # device where data will be stored and passed - update_at_each_batch=False, + device=collector_device, exploration_type=ExplorationType.RANDOM, ) @@ -961,7 +957,7 @@ def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb ############################################################################### -# We'll store the replay buffer in a temporary dirrectory on disk +# We'll store the replay buffer in a temporary directory on disk import tempfile @@ -977,17 +973,17 @@ def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb # size by dividing it by the length of the sub-trajectories yielded by our # data collector. # Regarding the batch-size, our sampling strategy will consist in sampling -# trajectories of length ``traj_len=200`` before selecting sub-trajecotries +# trajectories of length ``traj_len=200`` before selecting sub-trajectories # or length ``random_crop_len=25`` on which the loss will be computed. # This strategy balances the choice of storing whole trajectories of a certain -# length with the need for providing sampels with a sufficient heterogeneity +# length with the need for providing samples with a sufficient heterogeneity # to our loss. The following figure shows the dataflow from a collector # that gets 8 frames in each batch with 2 environments run in parallel, # feeds them to a replay buffer that contains 1000 trajectories and # samples sub-trajectories of 2 time steps each. # # .. figure:: /_static/img/replaybuffer_traj.png -# :alt: Storign trajectories in the replay buffer +# :alt: Storing trajectories in the replay buffer # # Let's start with the number of frames stored in the buffer @@ -1005,7 +1001,7 @@ def ceil_div(x, y): ############################################################################### # We also need to define how many updates we'll be doing per batch of data -# collected. This is known as the update-to-data or UTD ratio: +# collected. This is known as the update-to-data or ``UTD`` ratio: update_to_data = 64 ############################################################################### @@ -1032,7 +1028,7 @@ def ceil_div(x, y): # Loss module construction # ------------------------ # -# We build our loss module with the actor and qnet we've just created. +# We build our loss module with the actor and ``qnet`` we've just created. # Because we have target parameters to update, we _must_ create a target network # updater. # @@ -1189,7 +1185,7 @@ def ceil_div(x, y): # # .. note:: # As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 1M. +# use a greater value for ``total_frames`` for example, 1M. from matplotlib import pyplot as plt @@ -1205,7 +1201,7 @@ def ceil_div(x, y): # Conclusion # ---------- # -# In this tutorial, we have learnt how to code a loss module in TorchRL given +# In this tutorial, we have learned how to code a loss module in TorchRL given # the concrete example of DDPG. # # The key takeaways are: @@ -1215,3 +1211,11 @@ def ceil_div(x, y): # - How to use (or not) a target network, and how to update its parameters; # - How to create an optimizer associated with a loss module. # +# Next Steps +# ---------- +# +# To iterate further on this loss module we might consider: +# +# - Using `@dispatch` (see `[Feature] Distpatch IQL loss module `_.) +# - Allowing flexible TensorDict keys. +# diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index fcddd699b3a..e325a8af2a6 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -93,10 +93,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 679d625220c..b08832307a1 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -15,8 +15,8 @@ Key learnings: -- How to create an environment in TorchRL, transform its outputs, and collect data from this env; -- How to make your classes talk to each other using :class:`tensordict.TensorDict`; +- How to create an environment in TorchRL, transform its outputs, and collect data from this environment; +- How to make your classes talk to each other using :class:`~tensordict.TensorDict`; - The basics of building your training loop with TorchRL: - How to compute the advantage signal for policy gradient methods; @@ -56,7 +56,7 @@ # problem rather than re-inventing the wheel every time you want to train a policy. # # For completeness, here is a brief overview of what the loss computes, even though -# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows: +# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows: # 1. we will sample a batch of data by playing the # policy in the environment for a given number of steps. # 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using @@ -99,7 +99,7 @@ # 5. Finally, we will run our training loop and analyze the results. # # Throughout this tutorial, we'll be using the :mod:`tensordict` library. -# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract +# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract # what a module reads and writes and care less about the specific data # description and more about the algorithm itself. # @@ -113,10 +113,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore @@ -159,7 +160,7 @@ # actually return ``frame_skip`` frames). # -device = "cpu" if not torch.has_cuda else "cuda:0" +device = "cpu" if not torch.cuda.is_available() else "cuda:0" num_cells = 256 # number of cells in each layer i.e. output dim. lr = 3e-4 max_grad_norm = 1.0 @@ -174,22 +175,10 @@ # use. In general, the goal of an RL algorithm is to learn to solve the task # as fast as it can in terms of environment interactions: the lower the ``total_frames`` # the better. -# We also define a ``frame_skip``: in some contexts, repeating the same action -# multiple times over the course of a trajectory may be beneficial as it makes -# the behavior more consistent and less erratic. However, "skipping" -# too many frames will hamper training by reducing the reactivity of the actor -# to observation changes. -# -# When using ``frame_skip`` it is good practice to -# correct the other frame counts by the number of frames we are grouping -# together. If we configure a total count of X frames for training but -# use a ``frame_skip`` of Y, we will be actually collecting XY frames in total -# which exceeds our predefined budget. -# -frame_skip = 1 -frames_per_batch = 1000 // frame_skip +# +frames_per_batch = 1000 # For a complete training, bring the number of frames up to 1M -total_frames = 10_000 // frame_skip +total_frames = 10_000 ###################################################################### # PPO parameters @@ -220,23 +209,23 @@ # control system. Various libraries provide simulation environments for reinforcement # learning, including Gymnasium (previously OpenAI Gym), DeepMind control suite, and # many others. -# As a generalistic library, TorchRL's goal is to provide an interchangeable interface +# As a general library, TorchRL's goal is to provide an interchangeable interface # to a large panel of RL simulators, allowing you to easily swap one environment # with another. For example, creating a wrapped gym environment can be achieved with few characters: # -base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip) +base_env = GymEnv("InvertedDoublePendulum-v4", device=device) ###################################################################### # There are a few things to notice in this code: first, we created # the environment by calling the ``GymEnv`` wrapper. If extra keyword arguments # are passed, they will be transmitted to the ``gym.make`` method, hence covering -# the most common env construction commands. +# the most common environment construction commands. # Alternatively, one could also directly create a gym environment using ``gym.make(env_name, **kwargs)`` # and wrap it in a `GymWrapper` class. # # Also the ``device`` argument: for gym, this only controls the device where -# input action and observered states will be stored, but the execution will always +# input action and observed states will be stored, but the execution will always # be done on CPU. The reason for this is simply that gym does not support on-device # execution, unless specified otherwise. For other libraries, we have control over # the execution device and, as much as we can, we try to stay consistent in terms of @@ -248,9 +237,9 @@ # We will append some transforms to our environments to prepare the data for # the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different # approach, more similar to other pytorch domain libraries, through the use of transforms. -# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv` -# instance and append the sequence of transforms to it. The transformed env will inherit -# the device and meta-data of the wrapped env, and transform these depending on the sequence +# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv` +# instance and append the sequence of transforms to it. The transformed environment will inherit +# the device and meta-data of the wrapped environment, and transform these depending on the sequence # of transforms it contains. # # Normalization @@ -262,17 +251,17 @@ # run a certain number of random steps in the environment and compute # the summary statistics of these observations. # -# We'll append two other transforms: the :class:`DoubleToFloat` transform will +# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will # convert double entries to single-precision numbers, ready to be read by the -# policy. The :class:`StepCounter` transform will be used to count the steps before +# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before # the environment is terminated. We will use this measure as a supplementary measure # of performance. # -# As we will see later, many of the TorchRL's classes rely on :class:`tensordict.TensorDict` +# As we will see later, many of the TorchRL's classes rely on :class:`~tensordict.TensorDict` # to communicate. You could think of it as a python dictionary with some extra # tensor features. In practice, this means that many modules we will be working # with need to be told what key to read (``in_keys``) and what key to write -# (``out_keys``) in the tensordict they will receive. Usually, if ``out_keys`` +# (``out_keys``) in the ``tensordict`` they will receive. Usually, if ``out_keys`` # is omitted, it is assumed that the ``in_keys`` entries will be updated # in-place. For our transforms, the only entry we are interested in is referred # to as ``"observation"`` and our transform layers will be told to modify this @@ -284,22 +273,20 @@ Compose( # normalize observations ObservationNorm(in_keys=["observation"]), - DoubleToFloat( - in_keys=["observation"], - ), + DoubleToFloat(), StepCounter(), ), ) ###################################################################### # As you may have noticed, we have created a normalization layer but we did not -# set its normalization parameters. To do this, :class:`ObservationNorm` can +# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can # automatically gather the summary statistics of our environment: # env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0) ###################################################################### -# The :class:`ObservationNorm` transform has now been populated with a +# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a # location and a scale that will be used to normalize the data. # # Let us do a little sanity check for the shape of our summary stats: @@ -313,25 +300,23 @@ # For efficiency purposes, TorchRL is quite stringent when it comes to # environment specs, but you can easily check that your environment specs are # adequate. -# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits -# from it already take care of setting the proper specs for your env so +# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and +# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits +# from it already take care of setting the proper specs for your environment so # you should not have to care about this. # # Nevertheless, let's see a concrete example using our transformed # environment by looking at its specs. -# There are five specs to look at: ``observation_spec`` which defines what +# There are three specs to look at: ``observation_spec`` which defines what # is to be expected when executing an action in the environment, -# ``reward_spec`` which indicates the reward domain, -# ``done_spec`` which indicates the done state of an environment, -# the ``action_spec`` which defines the action space, dtype and device and -# the ``state_spec`` which groups together the specs of all the other inputs -# (if any) to the environment. +# ``reward_spec`` which indicates the reward domain and finally the +# ``input_spec`` (which contains the ``action_spec``) and which represents +# everything an environment requires to execute a single step. # print("observation_spec:", env.observation_spec) print("reward_spec:", env.reward_spec) -print("done_spec:", env.done_spec) -print("action_spec:", env.action_spec) -print("state_spec:", env.state_spec) +print("input_spec:", env.input_spec) +print("action_spec (as defined by input_spec):", env.action_spec) ###################################################################### # the :func:`check_env_specs` function runs a small rollout and compares its output against the environment @@ -349,9 +334,9 @@ # action as input, and outputs an observation, a reward and a done state. The # observation may be composite, meaning that it could be composed of more than one # tensor. This is not a problem for TorchRL, since the whole set of observations -# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout -# (ie a sequence of environment steps and random action generations) over a given -# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape +# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout +# (for example, a sequence of environment steps and random action generations) over a given +# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape # that matches this trajectory length: # rollout = env.rollout(3) @@ -361,8 +346,8 @@ ###################################################################### # Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps # we ran it for. The ``"next"`` entry points to the data coming after the current step. -# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this -# may not be the case if we are using some specific transformations (e.g. multi-step). +# In most cases, the ``"next"`` data at time `t` matches the data at ``t+1``, but this +# may not be the case if we are using some specific transformations (for example, multi-step). # # Policy # ------ @@ -388,10 +373,9 @@ # # 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``. # -# 2. Append a :class:`NormalParamExtractor` to extract a location and a scale (ie splits the input in two equal parts -# and applies a positive transformation to the scale parameter). +# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter). # -# 3. Create a probabilistic :class:`TensorDictModule` that can generate this distribution and sample from it. +# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it. # actor_net = nn.Sequential( @@ -406,8 +390,8 @@ ) ###################################################################### -# To enable the policy to "talk" with the environment through the tensordict -# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This +# To enable the policy to "talk" with the environment through the ``tensordict`` +# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This # class will simply ready the ``in_keys`` it is provided with and write the # outputs in-place at the registered ``out_keys``. # @@ -417,18 +401,19 @@ ###################################################################### # We now need to build a distribution out of the location and scale of our -# normal distribution. To do so, we instruct the :class:`ProbabilisticActor` -# class to build a :class:`TanhNormal` out of the location and scale +# normal distribution. To do so, we instruct the +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` +# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale # parameters. We also provide the minimum and maximum values of this # distribution, which we gather from the environment specs. # # The name of the ``in_keys`` (and hence the name of the ``out_keys`` from -# the :class:`TensorDictModule` above) cannot be set to any value one may -# like, as the :class:`TanhNormal` distribution constructor will expect the +# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may +# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the # ``loc`` and ``scale`` keyword arguments. That being said, -# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys`` -# where the key-value pair indicates what ``in_key`` string should be used for -# every keyword argument that is to be used. +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts +# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates +# what ``in_key`` string should be used for every keyword argument that is to be used. # policy_module = ProbabilisticActor( module=policy_module, @@ -436,8 +421,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.minimum, - "max": env.action_spec.space.maximum, + "min": env.action_spec.space.low, + "max": env.action_spec.space.high, }, return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights @@ -451,7 +436,7 @@ # won't be used at inference time. This module will read the observations and # return an estimation of the discounted return for the following trajectory. # This allows us to amortize learning by relying on the some utility estimation -# that is learnt on-the-fly during training. Our value network share the same +# that is learned on-the-fly during training. Our value network share the same # structure as the policy, but for simplicity we assign it its own set of # parameters. # @@ -472,7 +457,7 @@ ###################################################################### # let's try our policy and value modules. As we said earlier, the usage of -# :class:`TensorDictModule` makes it possible to directly read the output +# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output # of the environment to run these modules, as they know what information to read # and where to write it: # @@ -483,11 +468,11 @@ # Data collector # -------------- # -# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these -# classes execute three operations: reset an environment, compute an action -# given the latest observation, execute a step in the environment, and repeat -# the last two steps until the environment signals a stop (or reaches a done -# state). +# TorchRL provides a set of `DataCollector classes `__. +# Briefly, these classes execute three operations: reset an environment, +# compute an action given the latest observation, execute a step in the environment, +# and repeat the last two steps until the environment signals a stop (or reaches +# a done state). # # They allow you to control how many frames to collect at each iteration # (through the ``frames_per_batch`` parameter), @@ -495,18 +480,19 @@ # on which ``device`` the policy should be executed, etc. They are also # designed to work efficiently with batched and multiprocessed environments. # -# The simplest data collector is the :class:`SyncDataCollector`: it is an -# iterator that you can use to get batches of data of a given length, and +# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`: +# it is an iterator that you can use to get batches of data of a given length, and # that will stop once a total number of frames (``total_frames``) have been # collected. -# Other data collectors (``MultiSyncDataCollector`` and -# ``MultiaSyncDataCollector``) will execute the same operations in synchronous -# and asynchronous manner over a set of multiprocessed workers. +# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute +# the same operations in synchronous and asynchronous manner over a +# set of multiprocessed workers. # # As for the policy and environment before, the data collector will return -# :class:`tensordict.TensorDict` instances with a total number of elements that will -# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the -# training loop allows you to write dataloading pipelines +# :class:`~tensordict.TensorDict` instances with a total number of elements that will +# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the +# training loop allows you to write data loading pipelines # that are 100% oblivious to the actual specificities of the rollout content. # collector = SyncDataCollector( @@ -528,10 +514,10 @@ # of epochs. # # TorchRL's replay buffers are built using a common container -# :class:`ReplayBuffer` which takes as argument the components of the buffer: -# a storage, a writer, a sampler and possibly some transforms. Only the -# storage (which indicates the replay buffer capacity) is mandatory. We -# also specify a sampler without repetition to avoid sampling multiple times +# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components +# of the buffer: a storage, a writer, a sampler and possibly some transforms. +# Only the storage (which indicates the replay buffer capacity) is mandatory. +# We also specify a sampler without repetition to avoid sampling multiple times # the same item in one epoch. # Using a replay buffer for PPO is not mandatory and we could simply # sample the sub-batches from the collected batch, but using these classes @@ -539,7 +525,7 @@ # replay_buffer = ReplayBuffer( - storage=LazyTensorStorage(frames_per_batch), + storage=LazyTensorStorage(max_size=frames_per_batch), sampler=SamplerWithoutReplacement(), ) @@ -547,8 +533,8 @@ # Loss function # ------------- # -# The PPO loss can be directly imported from torchrl for convenience using the -# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO: +# The PPO loss can be directly imported from TorchRL for convenience using the +# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO: # it hides away the mathematical operations of PPO and the control flow that # goes with it. # @@ -558,11 +544,11 @@ # To compute the advantage, one just needs to (1) build the advantage module, which # utilizes our value operator, and (2) pass each batch of data through it before each # epoch. -# The GAE module will update the input :class:`TensorDict` with new ``"advantage"`` and +# The GAE module will update the input ``tensordict`` with new ``"advantage"`` and # ``"value_target"`` entries. # The ``"value_target"`` is a gradient-free tensor that represents the empirical # value that the value network should represent with the input observation. -# Both of these will be used by :class:`ClipPPOLoss` to +# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to # return the policy and value losses. # @@ -577,9 +563,7 @@ entropy_bonus=bool(entropy_eps), entropy_coef=entropy_eps, # these keys match by default but we set this for completeness - value_target_key=advantage_module.value_target_key, critic_coef=1.0, - gamma=0.99, loss_critic_type="smooth_l1", ) @@ -610,7 +594,7 @@ logs = defaultdict(list) -pbar = tqdm(total=total_frames * frame_skip) +pbar = tqdm(total=total_frames) eval_str = "" # We iterate over the collector until it reaches the total number of frames it was @@ -621,8 +605,7 @@ # We'll need an "advantage" signal to make PPO work. # We re-compute it at each epoch as its value depends on the value # network which is updated in the inner loop. - with torch.no_grad(): - advantage_module(tensordict_data) + advantage_module(tensordict_data) data_view = tensordict_data.reshape(-1) replay_buffer.extend(data_view.cpu()) for _ in range(frames_per_batch // sub_batch_size): @@ -634,7 +617,7 @@ + loss_vals["loss_entropy"] ) - # Optimization: backward, grad clipping and optim step + # Optimization: backward, grad clipping and optimization step loss_value.backward() # this is not strictly mandatory but it's good practice to keep # your gradient norm bounded @@ -643,7 +626,7 @@ optim.zero_grad() logs["reward"].append(tensordict_data["next", "reward"].mean().item()) - pbar.update(tensordict_data.numel() * frame_skip) + pbar.update(tensordict_data.numel()) cum_reward_str = ( f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})" ) @@ -655,8 +638,8 @@ # We evaluate the policy once every 10 batches of data. # Evaluation is rather simple: execute the policy without exploration # (take the expected value of the action distribution) for a given - # number of steps (1000, which is our env horizon). - # The ``rollout`` method of the env can take a policy as argument: + # number of steps (1000, which is our ``env`` horizon). + # The ``rollout`` method of the ``env`` can take a policy as argument: # it will then execute this policy at each step. with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): # execute a rollout with the trained policy @@ -717,7 +700,7 @@ # we could run several simulations in parallel to speed up data collection. # Check :class:`~torchrl.envs.ParallelEnv` for further information. # -# * From a logging perspective, one could add a :class:`~torchrl.record.VideoRecorder` transform to +# * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to # the environment after asking for rendering to get a visual rendering of the # inverted pendulum in action. Check :py:mod:`torchrl.record` to # know more. diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index a1c82d5c429..f7c8158962b 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -77,10 +77,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index a12c2b05ff8..515c9f0f934 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -19,10 +19,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 12c8bdc3193..12f481be032 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -83,10 +83,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index e8abf33cef8..d60de247c9f 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -23,10 +23,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 6106e3cf65a..6a8c69d00a4 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -56,10 +56,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index d1a261e63f5..68f36922e8f 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -134,10 +134,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index dc836b43150..e2f3384ac30 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -36,10 +36,11 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory +is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert multiprocessing.get_start_method() == "fork" + assert is_sphinx or (multiprocessing.get_start_method() == "fork") # sphinx_gallery_end_ignore From e3cd0b8f30bec49b9f018f2010831a74f5712837 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 1 Feb 2024 17:03:58 +0000 Subject: [PATCH 2/9] amend --- tutorials/sphinx-tutorials/coding_ddpg.py | 22 ++++++------------- tutorials/sphinx-tutorials/coding_dqn.py | 3 +-- tutorials/sphinx-tutorials/coding_ppo.py | 3 +-- tutorials/sphinx-tutorials/dqn_with_rnn.py | 3 +-- tutorials/sphinx-tutorials/multi_task.py | 3 +-- tutorials/sphinx-tutorials/pendulum.py | 3 +-- .../sphinx-tutorials/pretrained_models.py | 3 +-- tutorials/sphinx-tutorials/rb_tutorial.py | 3 +-- tutorials/sphinx-tutorials/torchrl_demo.py | 3 +-- tutorials/sphinx-tutorials/torchrl_envs.py | 3 +-- 10 files changed, 16 insertions(+), 33 deletions(-) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index dddd8963c2b..a131235d1d8 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -72,11 +72,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore @@ -87,7 +86,7 @@ ############################################################################### # We will execute the policy on CUDA if available -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA ############################################################################### @@ -247,25 +246,20 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): hp.update(hyperparams) value_key = "state_action_value" if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator( - value_network=self.actor_critic, **hp - ) + self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator( - value_network=self.actor_critic, **hp - ) + self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp) elif value_type == ValueEstimators.GAE: raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." ) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator( - value_network=self.actor_critic, **hp - ) + self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp) else: raise NotImplementedError(f"Unknown value type {value_type}") self._value_estimator.set_keys(value=value_key) + ############################################################################### # The ``make_value_estimator`` method can but does not need to be called: ifgg # not, the :class:`~torchrl.objectives.LossModule` will query this method with @@ -538,9 +532,7 @@ def make_transformed_env( # version of the transform env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True)) - env.append_transform( - DoubleToFloat() - ) + env.append_transform(DoubleToFloat()) env.append_transform(StepCounter(max_frames_per_traj)) diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index e325a8af2a6..d4f7f4798c6 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -93,11 +93,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index b08832307a1..5c157a93a97 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -113,11 +113,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index f7c8158962b..2d3b8c970d4 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -77,11 +77,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 515c9f0f934..b3ac87cc921 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -19,11 +19,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 12f481be032..a9f89641183 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -83,11 +83,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index d60de247c9f..f67475d16bb 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -23,11 +23,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 6a8c69d00a4..63802cf0052 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -56,11 +56,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 68f36922e8f..f137a499ef4 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -134,11 +134,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index e2f3384ac30..84a5a499723 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -36,11 +36,10 @@ # TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory -is_sphinx = 'sphinx_gallery_conf' in globals() try: multiprocessing.set_start_method("fork") except RuntimeError: - assert is_sphinx or (multiprocessing.get_start_method() == "fork") + pass # sphinx_gallery_end_ignore From 02bd500b3bf032d30179a1805a266b7f5fabd5da Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 1 Feb 2024 21:02:57 +0000 Subject: [PATCH 3/9] amend --- tutorials/sphinx-tutorials/coding_dqn.py | 16 ++++++++++++---- tutorials/sphinx-tutorials/dqn_with_rnn.py | 20 ++++++++++++++------ tutorials/sphinx-tutorials/rb_tutorial.py | 17 +++++++++++++---- tutorials/sphinx-tutorials/torchrl_envs.py | 4 ++-- 4 files changed, 41 insertions(+), 16 deletions(-) diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index d4f7f4798c6..d543dd0d876 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -86,6 +86,8 @@ import tempfile import warnings +from tensordict.nn import TensorDictSequential + warnings.filterwarnings("ignore") from torch import multiprocessing @@ -125,7 +127,7 @@ ToTensorImage, TransformedEnv, ) -from torchrl.modules import DuelingCnnDQNet, EGreedyWrapper, QValueActor +from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.record.loggers.csv import CSVLogger @@ -328,13 +330,13 @@ def make_model(dummy_env): tensordict = dummy_env.fake_tensordict() actor(tensordict) - # we wrap our actor in an EGreedyWrapper for data collection - actor_explore = EGreedyWrapper( - actor, + # we join our actor with an EGreedyModule for data collection + exploration_module = EGreedyModule( annealing_num_steps=total_frames, eps_init=eps_greedy_val, eps_end=eps_greedy_val_env, ) + actor_explore = TensorDictSequential(actor, exploration_module) return actor, actor_explore @@ -642,6 +644,12 @@ def get_loss_module(actor, gamma): ) recorder.register(trainer) +############################################################################### +# The exploration module epsilon factor is also annealed: +# + +trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch) + ############################################################################### # - Any callable (including :class:`~torchrl.trainers.TrainerHookBase` # subclasses) can be registered using :meth:`~torchrl.trainers.Trainer.register_op`. diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 2d3b8c970d4..8b4022ded39 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -86,7 +86,11 @@ import torch import tqdm -from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictSequential, + TensorDictSequential as Seq, +) from torch import nn from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -104,7 +108,7 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv -from torchrl.modules import ConvNet, EGreedyWrapper, LSTMModule, MLP, QValueModule +from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule from torchrl.objectives import DQNLoss, SoftUpdate device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu") @@ -309,11 +313,15 @@ # DQN being a deterministic algorithm, exploration is a crucial part of it. # We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying # progressively to 0. -# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step` +# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step` # (see training loop below). # -stoch_policy = EGreedyWrapper( - stoch_policy, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 +exploration_module = EGreedyModule( + annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 +) +stoch_policy = TensorDictSequential( + stoch_policy, + exploration_module, ) ###################################################################### @@ -419,7 +427,7 @@ pbar.set_description( f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}" ) - stoch_policy.step(data.numel()) + exploration_module.step(data.numel()) updater.step() with set_exploration_type(ExplorationType.MODE), torch.no_grad(): diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 63802cf0052..66a3e080ab5 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -344,12 +344,21 @@ def assert0(x): # # Fixed batch-size # ~~~~~~~~~~~~~~~~ -# If the batch-size is passed during construction, it should be ommited when +# If the batch-size is passed during construction, it should be omited when # sampling: +data = MyData( + images=torch.randint( + 255, + (1000, 64, 64, 3), + ), + labels=torch.randint(100, (1000,)), + batch_size=[1000], +) + buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) -buffer_lazymemmap.extend(data) -buffer_lazymemmap.sample() +buffer_lazymemmap.add(data) +buffer_lazymemmap.sample() # will produces 128 identical samples ###################################################################### @@ -363,7 +372,7 @@ def assert0(x): buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), batch_size=128, prefetch=10 ) # creates a queue of 10 elements to be prefetched in the background -buffer_lazymemmap.extend(data) +buffer_lazymemmap.add(data) print(buffer_lazymemmap.sample()) diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 84a5a499723..77fd7ba49f7 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -575,7 +575,7 @@ def env_make(env_name): parallel_env = ParallelEnv( 2, [env_make, env_make], - [{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], + create_env_kwargs=[{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], ) tensordict = parallel_env.reset() @@ -619,7 +619,7 @@ def env_make(env_name): parallel_env = ParallelEnv( 2, [env_make, env_make], - [{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], + create_env_kwargs=[{"env_name": "ALE/AirRaid-v5"}, {"env_name": "ALE/Pong-v5"}], ) parallel_env = TransformedEnv(parallel_env, GrayScale()) # transforms on main process tensordict = parallel_env.reset() From 71b9ff65a3d870dc26e7de8e404c1e2187b4cafe Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Feb 2024 08:15:13 +0000 Subject: [PATCH 4/9] amend --- docs/source/conf.py | 4 +++ tutorials/sphinx-tutorials/coding_ddpg.py | 7 +++- tutorials/sphinx-tutorials/coding_dqn.py | 32 ++++++++++++++----- tutorials/sphinx-tutorials/coding_ppo.py | 7 +++- tutorials/sphinx-tutorials/dqn_with_rnn.py | 7 +++- tutorials/sphinx-tutorials/multi_task.py | 7 +++- tutorials/sphinx-tutorials/pendulum.py | 7 +++- .../sphinx-tutorials/pretrained_models.py | 7 +++- tutorials/sphinx-tutorials/rb_tutorial.py | 7 +++- tutorials/sphinx-tutorials/torchrl_demo.py | 7 +++- tutorials/sphinx-tutorials/torchrl_envs.py | 7 +++- 11 files changed, 82 insertions(+), 17 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index f0821ede0bf..2b57387caf9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -189,3 +189,7 @@ generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial") # generate_tutorial_references("../../tutorials/src/", "src") generate_tutorial_references("../../tutorials/media/", "media") + +# We do this to indicate that the script is run by sphinx +import builtins +builtins.__sphinx_build__ = True diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index a131235d1d8..3fb4560240a 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -73,7 +73,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index d543dd0d876..e4e27b5ae1c 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -96,19 +96,23 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass # sphinx_gallery_end_ignore - import os import uuid import torch from torch import nn -from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer from torchrl.envs import ( EnvCreator, @@ -332,6 +336,7 @@ def make_model(dummy_env): # we join our actor with an EGreedyModule for data collection exploration_module = EGreedyModule( + spec=dummy_env.action_spec, annealing_num_steps=total_frames, eps_init=eps_greedy_val, eps_end=eps_greedy_val_env, @@ -383,6 +388,13 @@ def get_replay_buffer(buffer_size, n_optim, batch_size): # We choose the following configuration: we will be running a series of # parallel environments synchronously in parallel in different collectors, # themselves running in parallel but asynchronously. +# +# .. note:: +# This feature is only available when running the code within the "spawn" +# start method of python multiprocessing library. If this tutorial is run +# directly as a script (thereby using the "fork" method) we will be using +# a regular :class:`~torchrl.collectors.SyncDataCollector`. +# # The advantage of this configuration is that we can balance the amount of # compute that is executed in batch with what we want to be executed # asynchronously. We encourage the reader to experiment how the collection @@ -411,11 +423,15 @@ def get_collector( total_frames, device, ): - data_collector = MultiaSyncDataCollector( - [ - make_env(parallel=True, obs_norm_sd=stats), - ] - * num_collectors, + is_fork = multiprocessing.get_start_method() == "fork" + if is_fork: + cls = SyncDataCollector + env_arg = make_env(parallel=True, obs_norm_sd=stats) + else: + cls = MultiaSyncDataCollector + env_arg = [make_env(parallel=True, obs_norm_sd=stats)]*num_collectors + data_collector = cls( + env_arg, policy=actor_explore, frames_per_batch=frames_per_batch, total_frames=total_frames, diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 5c157a93a97..e7e6c8fd7f1 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -114,7 +114,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 8b4022ded39..ed23ee085d3 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -78,7 +78,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index b3ac87cc921..68cb995a1a3 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -20,7 +20,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index a9f89641183..a67976566d5 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -84,7 +84,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index f67475d16bb..0d78a8ec6a1 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -24,7 +24,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 66a3e080ab5..98529507c7f 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -57,7 +57,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index f137a499ef4..89bb1680bf8 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -135,7 +135,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 77fd7ba49f7..56896637a87 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -37,7 +37,12 @@ # `__main__` method call, but for the easy of reading the code switch to fork # which is also a default spawn method in Google's Colaboratory try: - multiprocessing.set_start_method("fork") + is_sphinx = __sphinx_build__ +except NameError: + is_sphinx = False + +try: + multiprocessing.set_start_method("spawn" if is_sphinx else "fork") except RuntimeError: pass From 54ed708560e4a2780811eb99ad089e63744ceae3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Feb 2024 08:49:25 +0000 Subject: [PATCH 5/9] amend --- docs/source/conf.py | 1 + tutorials/sphinx-tutorials/coding_dqn.py | 12 ++++------- tutorials/sphinx-tutorials/rb_tutorial.py | 26 +++++++++++------------ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 2b57387caf9..060103b48b4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -192,4 +192,5 @@ # We do this to indicate that the script is run by sphinx import builtins + builtins.__sphinx_build__ = True diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index e4e27b5ae1c..6f37bdc2a17 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -112,7 +112,7 @@ import torch from torch import nn -from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector +from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer from torchrl.envs import ( EnvCreator, @@ -276,6 +276,7 @@ def get_norm_stats(): # let's check that normalizing constants have a size of ``[C, 1, 1]`` where # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`). print("state dict of the observation norm:", obs_norm_sd) + test_env.close() return obs_norm_sd @@ -423,13 +424,8 @@ def get_collector( total_frames, device, ): - is_fork = multiprocessing.get_start_method() == "fork" - if is_fork: - cls = SyncDataCollector - env_arg = make_env(parallel=True, obs_norm_sd=stats) - else: - cls = MultiaSyncDataCollector - env_arg = [make_env(parallel=True, obs_norm_sd=stats)]*num_collectors + cls = MultiaSyncDataCollector + env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors data_collector = cls( env_arg, policy=actor_explore, diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 98529507c7f..3d37ce3de83 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -138,7 +138,7 @@ from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage # We define the maximum size of the buffer -size = 10_000 +size = 100 ###################################################################### # A buffer with a list storage buffer can store any kind of data (but we must @@ -265,10 +265,10 @@ class MyData: data = MyData( images=torch.randint( 255, - (1000, 64, 64, 3), + (10, 64, 64, 3), ), - labels=torch.randint(100, (1000,)), - batch_size=[1000], + labels=torch.randint(100, (10,)), + batch_size=[10], ) tempdir = tempfile.TemporaryDirectory() @@ -308,7 +308,7 @@ def transform(x): # Let's build our replay buffer on disk: -rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform) +rb = ReplayBuffer(storage=LazyMemmapStorage(size), transform=transform) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, @@ -355,10 +355,10 @@ def assert0(x): data = MyData( images=torch.randint( 255, - (1000, 64, 64, 3), + (10, 64, 64, 3), ), - labels=torch.randint(100, (1000,)), - batch_size=[1000], + labels=torch.randint(100, (10,)), + batch_size=[10], ) buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) @@ -411,10 +411,10 @@ def assert0(x): # we create a data that is big enough to get a couple of samples data = TensorDict( { - "a": torch.arange(512).view(128, 4), - ("b", "c"): torch.arange(1024).view(128, 8), + "a": torch.arange(64).view(16, 4), + ("b", "c"): torch.arange(128).view(16, 8), }, - batch_size=[128], + batch_size=[16], ) buffer_lazymemmap.extend(data) @@ -457,7 +457,7 @@ def assert0(x): from torchrl.data.replay_buffers.samplers import PrioritizedSampler -size = 1000 +size = 100 rb = ReplayBuffer( storage=ListStorage(size), @@ -732,7 +732,7 @@ def assert0(x): GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) -rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) +rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16) data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) rb.add(data_exclude) From 651b1cb4247c7a80982e8ed73e7682704b0fe07e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Feb 2024 08:52:09 +0000 Subject: [PATCH 6/9] amend --- tutorials/sphinx-tutorials/coding_ddpg.py | 3 ++- tutorials/sphinx-tutorials/coding_dqn.py | 3 ++- tutorials/sphinx-tutorials/coding_ppo.py | 3 ++- tutorials/sphinx-tutorials/dqn_with_rnn.py | 3 ++- tutorials/sphinx-tutorials/multiagent_ppo.py | 5 ++++- tutorials/sphinx-tutorials/pretrained_models.py | 3 ++- 6 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 3fb4560240a..b75e2facbeb 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -91,7 +91,8 @@ ############################################################################### # We will execute the policy on CUDA if available -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +is_fork = multiprocessing.get_start_method() == "fork" +device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA ############################################################################### diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 6f37bdc2a17..2a62327a39b 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -478,7 +478,8 @@ def get_loss_module(actor, gamma): # in practice, and the performance of the algorithm should hopefully not be # too sensitive to slight variations of these. -device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu" +is_fork = multiprocessing.get_start_method() == "fork" +device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") ############################################################################### # Optimizer diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index e7e6c8fd7f1..38a40c7725b 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -164,7 +164,8 @@ # actually return ``frame_skip`` frames). # -device = "cpu" if not torch.cuda.is_available() else "cuda:0" +is_fork = multiprocessing.get_start_method() == "fork" +device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") num_cells = 256 # number of cells in each layer i.e. output dim. lr = 3e-4 max_grad_norm = 1.0 diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index ed23ee085d3..a0f1d32a436 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -116,7 +116,8 @@ from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule from torchrl.objectives import DQNLoss, SoftUpdate -device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu") +is_fork = multiprocessing.get_start_method() == "fork" +device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") ###################################################################### # Environment diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index 90fd82dab3c..a1bb2420ed5 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -123,6 +123,8 @@ import torch # Tensordict modules +from torch import multiprocessing + from tensordict.nn import TensorDictModule from tensordict.nn.distributions import NormalParamExtractor @@ -161,7 +163,8 @@ # # Devices -device = "cpu" if not torch.has_cuda else "cuda:0" # The divice where learning is run +is_fork = multiprocessing.get_start_method() == "fork" +device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") vmas_device = device # The device where the simulator is run (VMAS can run on GPU) # Sampling diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index 0d78a8ec6a1..0d086d69a82 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -42,7 +42,8 @@ from torchrl.envs.libs.gym import GymEnv from torchrl.modules import Actor -device = "cuda:0" if torch.cuda.device_count() else "cpu" +is_fork = multiprocessing.get_start_method() == "fork" +device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") ############################################################################## # Let us first create an environment. For the sake of simplicity, we will be using From 851654093ef95f696b1f33a8a0151e3de8be872a Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Feb 2024 11:10:30 +0000 Subject: [PATCH 7/9] amend --- tutorials/sphinx-tutorials/torchrl_demo.py | 221 ++++++++++++++------- 1 file changed, 152 insertions(+), 69 deletions(-) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 89bb1680bf8..b112aa64a50 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -32,75 +32,158 @@ # **Content**: # .. aafig:: # -# "torchrl" -# │ -# ├── "collectors" -# │ └── "collectors.py" -# ├── "data" -# │ ├── "tensor_specs.py" -# │ ├── "postprocs" -# │ │ └── "postprocs.py" -# │ └── "replay_buffers" -# │ ├── "replay_buffers.py" -# │ └── "storages.py" -# ├── "envs" -# │ ├── "common.py" -# │ ├── "env_creator.py" -# │ ├── "gym_like.py" -# │ ├── "vec_env.py" -# │ ├── "libs" -# │ │ ├── "dm_control.py" -# │ │ └── "gym.py" -# │ └── "transforms" -# │ ├── "functional.py" -# │ └── "transforms.py" -# ├── "modules" -# │ ├── "distributions" -# │ │ ├── "continuous.py" -# │ │ └── "discrete.py" -# │ ├── "models" -# │ │ ├── "models.py" -# │ │ └── "exploration.py" -# │ └── "tensordict_module" -# │ ├── "actors.py" -# │ ├── "common.py" -# │ ├── "exploration.py" -# │ ├── "probabilistic.py" -# │ └── "sequence.py" -# ├── "objectives" -# │ ├── "common.py" -# │ ├── "ddpg.py" -# │ ├── "dqn.py" -# │ ├── "functional.py" -# │ ├── "ppo.py" -# │ ├── "redq.py" -# │ ├── "reinforce.py" -# │ ├── "sac.py" -# │ ├── "utils.py" -# │ └── "value" -# │ ├── "advantages.py" -# │ ├── "functional.py" -# │ ├── "pg.py" -# │ ├── "utils.py" -# │ └── "vtrace.py" -# ├── "record" -# │ └── "recorder.py" -# └── "trainers" -# ├── "loggers" -# │ ├── "common.py" -# │ ├── "csv.py" -# │ ├── "mlflow.py" -# │ ├── "tensorboard.py" -# │ └── "wandb.py" -# ├── "trainers.py" -# └── "helpers" -# ├── "collectors.py" -# ├── "envs.py" -# ├── "loggers.py" -# ├── "losses.py" -# ├── "models.py" -# ├── "replay_buffer.py" -# └── "trainers.py" +# ├── "torchrl" +# │ │ +# │ ├── "collectors" +# │ │ └── "collectors.py" +# │ │ │ +# │ │ ├── "distributed" +# │ │ │ └── "default_configs.py" +# │ │ │ └── "generic.py" +# │ │ │ └── "ray.py" +# │ │ │ └── "rpc.py" +# │ │ │ └── "sync.py" +# │ │ +# │ ├── "data" +# │ │ │ +# │ │ ├── "datasets" +# │ │ │ └── "atari_dqn.py" +# │ │ │ └── "d4rl.py" +# │ │ │ └── "d4rl_infos.py" +# │ │ │ └── "gen_dgrl.py" +# │ │ │ └── "minari_data.py" +# │ │ │ └── "openml.py" +# │ │ │ └── "openx.py" +# │ │ │ └── "roboset.py" +# │ │ │ └── "vd4rl.py" +# │ │ │ +# │ │ ├── "postprocs" +# │ │ │ └── "postprocs.py" +# │ │ │ +# │ │ ├── "replay_buffers" +# │ │ │ └── "replay_buffers.py" +# │ │ │ └── "samplers.py" +# │ │ │ └── "storages.py" +# │ │ │ └── "transforms.py" +# │ │ │ └── "writers.py" +# │ │ │ +# │ │ ├── "rlhf" +# │ │ │ └── "dataset.py" +# │ │ │ └── "prompt.py" +# │ │ │ └── "reward.py" +# │ │ └── "tensor_specs.py" +# │ │ +# │ ├── "envs" +# │ │ └── "batched_envs.py" +# │ │ └── "common.py" +# │ │ └── "env_creator.py" +# │ │ └── "gym_like.py" +# │ │ │ +# │ │ ├── "libs" +# │ │ │ └── "brax.py" +# │ │ │ └── "dm_control.py" +# │ │ │ └── "envpool.py" +# │ │ │ └── "gym.py" +# │ │ │ └── "habitat.py" +# │ │ │ └── "isaacgym.py" +# │ │ │ └── "jumanji.py" +# │ │ │ └── "openml.py" +# │ │ │ └── "pettingzoo.py" +# │ │ │ └── "robohive.py" +# │ │ │ └── "smacv2.py" +# │ │ │ └── "vmas.py" +# │ │ │ +# │ │ ├── "model_based" +# │ │ │ └── "common.py" +# │ │ │ └── "dreamer.py" +# │ │ │ +# │ │ ├── "transforms" +# │ │ │ └── "functional.py" +# │ │ │ └── "gym_transforms.py" +# │ │ │ └── "r3m.py" +# │ │ │ └── "rlhf.py" +# │ │ │ └── "transforms.py" +# │ │ │ └── "vc1.py" +# │ │ │ └── "vip.py" +# │ │ └── "vec_envs.py" +# │ │ +# │ ├── "modules" +# │ │ │ +# │ │ ├── "distributions" +# │ │ │ └── "continuous.py" +# │ │ │ └── "discrete.py" +# │ │ │ └── "truncated_normal.py" +# │ │ │ +# │ │ ├── "models" +# │ │ │ └── "decision_transformer.py" +# │ │ │ └── "exploration.py" +# │ │ │ └── "model_based.py" +# │ │ │ └── "models.py" +# │ │ │ └── "multiagent.py" +# │ │ │ └── "rlhf.py" +# │ │ │ +# │ │ ├── "planners" +# │ │ │ └── "cem.py" +# │ │ │ └── "common.py" +# │ │ │ └── "mppi.py" +# │ │ │ +# │ │ ├── "tensordict_module" +# │ │ │ └── "actors.py" +# │ │ │ └── "common.py" +# │ │ │ └── "exploration.py" +# │ │ │ └── "probabilistic.py" +# │ │ │ └── "rnn.py" +# │ │ │ └── "sequence.py" +# │ │ │ └── "world_models.py" +# │ │ +# │ ├── "objectives" +# │ │ └── "a2c.py" +# │ │ └── "common.py" +# │ │ └── "cql.py" +# │ │ └── "ddpg.py" +# │ │ └── "decision_transformer.py" +# │ │ └── "deprecated.py" +# │ │ └── "dqn.py" +# │ │ └── "dreamer.py" +# │ │ └── "functional.py" +# │ │ └── "iql.py" +# │ │ │ +# │ │ ├── "multiagent" +# │ │ │ └── "qmixer.py" +# │ │ └── "ppo.py" +# │ │ └── "redq.py" +# │ │ └── "reinforce.py" +# │ │ └── "sac.py" +# │ │ └── "td3.py" +# │ │ │ +# │ │ ├── "value" +# │ │ │ └── "advantages.py" +# │ │ │ └── "functional.py" +# │ │ │ └── "pg.py" +# │ │ +# │ ├── "record" +# │ │ │ +# │ │ ├── "loggers" +# │ │ │ └── "common.py" +# │ │ │ └── "csv.py" +# │ │ │ └── "mlflow.py" +# │ │ │ └── "tensorboard.py" +# │ │ │ └── "wandb.py" +# │ │ └── "recorder.py" +# │ │ +# │ ├── "trainers" +# │ │ │ +# │ │ ├── "helpers" +# │ │ │ └── "collectors.py" +# │ │ │ └── "envs.py" +# │ │ │ └── "logger.py" +# │ │ │ └── "losses.py" +# │ │ │ └── "models.py" +# │ │ │ └── "replay_buffer.py" +# │ │ │ └── "trainers.py" +# │ │ └── "trainers.py" +# │ └── "version.py" +# # # Unlike other domains, RL is less about media than *algorithms*. As such, it # is harder to make truly independent components. From af83e067954ed6ef8df3689cd9a2d6116b23f0cf Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Feb 2024 11:10:46 +0000 Subject: [PATCH 8/9] amend --- tutorials/sphinx-tutorials/coding_ddpg.py | 6 +++++- tutorials/sphinx-tutorials/coding_dqn.py | 6 +++++- tutorials/sphinx-tutorials/coding_ppo.py | 6 +++++- tutorials/sphinx-tutorials/dqn_with_rnn.py | 6 +++++- tutorials/sphinx-tutorials/multiagent_ppo.py | 12 ++++++++---- tutorials/sphinx-tutorials/pretrained_models.py | 6 +++++- 6 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index b75e2facbeb..5f8bf2c0830 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -92,7 +92,11 @@ ############################################################################### # We will execute the policy on CUDA if available is_fork = multiprocessing.get_start_method() == "fork" -device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA ############################################################################### diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 2a62327a39b..f85f6bf1e14 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -479,7 +479,11 @@ def get_loss_module(actor, gamma): # too sensitive to slight variations of these. is_fork = multiprocessing.get_start_method() == "fork" -device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ############################################################################### # Optimizer diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 38a40c7725b..be82bbd3bd8 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -165,7 +165,11 @@ # is_fork = multiprocessing.get_start_method() == "fork" -device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) num_cells = 256 # number of cells in each layer i.e. output dim. lr = 3e-4 max_grad_norm = 1.0 diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index a0f1d32a436..b71a112c91a 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -117,7 +117,11 @@ from torchrl.objectives import DQNLoss, SoftUpdate is_fork = multiprocessing.get_start_method() == "fork" -device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ###################################################################### # Environment diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index a1bb2420ed5..7451d6b39e7 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -122,12 +122,12 @@ # Torch import torch -# Tensordict modules -from torch import multiprocessing - from tensordict.nn import TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +# Tensordict modules +from torch import multiprocessing + # Data collection from torchrl.collectors import SyncDataCollector from torchrl.data.replay_buffers import ReplayBuffer @@ -164,7 +164,11 @@ # Devices is_fork = multiprocessing.get_start_method() == "fork" -device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) vmas_device = device # The device where the simulator is run (VMAS can run on GPU) # Sampling diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index 0d086d69a82..03265c50d2b 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -43,7 +43,11 @@ from torchrl.modules import Actor is_fork = multiprocessing.get_start_method() == "fork" -device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu") +device = ( + torch.device(0) + if torch.cuda.is_available() and not is_fork + else torch.device("cpu") +) ############################################################################## # Let us first create an environment. For the sake of simplicity, we will be using From 3cd84793433279b7b0b6c40088d04351df4f7d67 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 2 Feb 2024 11:53:44 +0000 Subject: [PATCH 9/9] amend --- tutorials/sphinx-tutorials/torchrl_demo.py | 284 ++++++++++----------- 1 file changed, 132 insertions(+), 152 deletions(-) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index b112aa64a50..5e00442fe36 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -32,158 +32,138 @@ # **Content**: # .. aafig:: # -# ├── "torchrl" -# │ │ -# │ ├── "collectors" -# │ │ └── "collectors.py" -# │ │ │ -# │ │ ├── "distributed" -# │ │ │ └── "default_configs.py" -# │ │ │ └── "generic.py" -# │ │ │ └── "ray.py" -# │ │ │ └── "rpc.py" -# │ │ │ └── "sync.py" -# │ │ -# │ ├── "data" -# │ │ │ -# │ │ ├── "datasets" -# │ │ │ └── "atari_dqn.py" -# │ │ │ └── "d4rl.py" -# │ │ │ └── "d4rl_infos.py" -# │ │ │ └── "gen_dgrl.py" -# │ │ │ └── "minari_data.py" -# │ │ │ └── "openml.py" -# │ │ │ └── "openx.py" -# │ │ │ └── "roboset.py" -# │ │ │ └── "vd4rl.py" -# │ │ │ -# │ │ ├── "postprocs" -# │ │ │ └── "postprocs.py" -# │ │ │ -# │ │ ├── "replay_buffers" -# │ │ │ └── "replay_buffers.py" -# │ │ │ └── "samplers.py" -# │ │ │ └── "storages.py" -# │ │ │ └── "transforms.py" -# │ │ │ └── "writers.py" -# │ │ │ -# │ │ ├── "rlhf" -# │ │ │ └── "dataset.py" -# │ │ │ └── "prompt.py" -# │ │ │ └── "reward.py" -# │ │ └── "tensor_specs.py" -# │ │ -# │ ├── "envs" -# │ │ └── "batched_envs.py" -# │ │ └── "common.py" -# │ │ └── "env_creator.py" -# │ │ └── "gym_like.py" -# │ │ │ -# │ │ ├── "libs" -# │ │ │ └── "brax.py" -# │ │ │ └── "dm_control.py" -# │ │ │ └── "envpool.py" -# │ │ │ └── "gym.py" -# │ │ │ └── "habitat.py" -# │ │ │ └── "isaacgym.py" -# │ │ │ └── "jumanji.py" -# │ │ │ └── "openml.py" -# │ │ │ └── "pettingzoo.py" -# │ │ │ └── "robohive.py" -# │ │ │ └── "smacv2.py" -# │ │ │ └── "vmas.py" -# │ │ │ -# │ │ ├── "model_based" -# │ │ │ └── "common.py" -# │ │ │ └── "dreamer.py" -# │ │ │ -# │ │ ├── "transforms" -# │ │ │ └── "functional.py" -# │ │ │ └── "gym_transforms.py" -# │ │ │ └── "r3m.py" -# │ │ │ └── "rlhf.py" -# │ │ │ └── "transforms.py" -# │ │ │ └── "vc1.py" -# │ │ │ └── "vip.py" -# │ │ └── "vec_envs.py" -# │ │ -# │ ├── "modules" -# │ │ │ -# │ │ ├── "distributions" -# │ │ │ └── "continuous.py" -# │ │ │ └── "discrete.py" -# │ │ │ └── "truncated_normal.py" -# │ │ │ -# │ │ ├── "models" -# │ │ │ └── "decision_transformer.py" -# │ │ │ └── "exploration.py" -# │ │ │ └── "model_based.py" -# │ │ │ └── "models.py" -# │ │ │ └── "multiagent.py" -# │ │ │ └── "rlhf.py" -# │ │ │ -# │ │ ├── "planners" -# │ │ │ └── "cem.py" -# │ │ │ └── "common.py" -# │ │ │ └── "mppi.py" -# │ │ │ -# │ │ ├── "tensordict_module" -# │ │ │ └── "actors.py" -# │ │ │ └── "common.py" -# │ │ │ └── "exploration.py" -# │ │ │ └── "probabilistic.py" -# │ │ │ └── "rnn.py" -# │ │ │ └── "sequence.py" -# │ │ │ └── "world_models.py" -# │ │ -# │ ├── "objectives" -# │ │ └── "a2c.py" -# │ │ └── "common.py" -# │ │ └── "cql.py" -# │ │ └── "ddpg.py" -# │ │ └── "decision_transformer.py" -# │ │ └── "deprecated.py" -# │ │ └── "dqn.py" -# │ │ └── "dreamer.py" -# │ │ └── "functional.py" -# │ │ └── "iql.py" -# │ │ │ -# │ │ ├── "multiagent" -# │ │ │ └── "qmixer.py" -# │ │ └── "ppo.py" -# │ │ └── "redq.py" -# │ │ └── "reinforce.py" -# │ │ └── "sac.py" -# │ │ └── "td3.py" -# │ │ │ -# │ │ ├── "value" -# │ │ │ └── "advantages.py" -# │ │ │ └── "functional.py" -# │ │ │ └── "pg.py" -# │ │ -# │ ├── "record" -# │ │ │ -# │ │ ├── "loggers" -# │ │ │ └── "common.py" -# │ │ │ └── "csv.py" -# │ │ │ └── "mlflow.py" -# │ │ │ └── "tensorboard.py" -# │ │ │ └── "wandb.py" -# │ │ └── "recorder.py" -# │ │ -# │ ├── "trainers" -# │ │ │ -# │ │ ├── "helpers" -# │ │ │ └── "collectors.py" -# │ │ │ └── "envs.py" -# │ │ │ └── "logger.py" -# │ │ │ └── "losses.py" -# │ │ │ └── "models.py" -# │ │ │ └── "replay_buffer.py" -# │ │ │ └── "trainers.py" -# │ │ └── "trainers.py" -# │ └── "version.py" -# +# "torchrl" +# │ +# ├── "collectors" +# │ └── "collectors.py" +# │ │ +# │ └── "distributed" +# │ └── "default_configs.py" +# │ └── "generic.py" +# │ └── "ray.py" +# │ └── "rpc.py" +# │ └── "sync.py" +# ├── "data" +# │ │ +# │ ├── "datasets" +# │ │ └── "atari_dqn.py" +# │ │ └── "d4rl.py" +# │ │ └── "d4rl_infos.py" +# │ │ └── "gen_dgrl.py" +# │ │ └── "minari_data.py" +# │ │ └── "openml.py" +# │ │ └── "openx.py" +# │ │ └── "roboset.py" +# │ │ └── "vd4rl.py" +# │ ├── "postprocs" +# │ │ └── "postprocs.py" +# │ ├── "replay_buffers" +# │ │ └── "replay_buffers.py" +# │ │ └── "samplers.py" +# │ │ └── "storages.py" +# │ │ └── "transforms.py" +# │ │ └── "writers.py" +# │ ├── "rlhf" +# │ │ └── "dataset.py" +# │ │ └── "prompt.py" +# │ │ └── "reward.py" +# │ └── "tensor_specs.py" +# ├── "envs" +# │ └── "batched_envs.py" +# │ └── "common.py" +# │ └── "env_creator.py" +# │ └── "gym_like.py" +# │ ├── "libs" +# │ │ └── "brax.py" +# │ │ └── "dm_control.py" +# │ │ └── "envpool.py" +# │ │ └── "gym.py" +# │ │ └── "habitat.py" +# │ │ └── "isaacgym.py" +# │ │ └── "jumanji.py" +# │ │ └── "openml.py" +# │ │ └── "pettingzoo.py" +# │ │ └── "robohive.py" +# │ │ └── "smacv2.py" +# │ │ └── "vmas.py" +# │ ├── "model_based" +# │ │ └── "common.py" +# │ │ └── "dreamer.py" +# │ ├── "transforms" +# │ │ └── "functional.py" +# │ │ └── "gym_transforms.py" +# │ │ └── "r3m.py" +# │ │ └── "rlhf.py" +# │ │ └── "transforms.py" +# │ │ └── "vc1.py" +# │ │ └── "vip.py" +# │ └── "vec_envs.py" +# ├── "modules" +# │ ├── "distributions" +# │ │ └── "continuous.py" +# │ │ └── "discrete.py" +# │ │ └── "truncated_normal.py" +# │ ├── "models" +# │ │ └── "decision_transformer.py" +# │ │ └── "exploration.py" +# │ │ └── "model_based.py" +# │ │ └── "models.py" +# │ │ └── "multiagent.py" +# │ │ └── "rlhf.py" +# │ ├── "planners" +# │ │ └── "cem.py" +# │ │ └── "common.py" +# │ │ └── "mppi.py" +# │ └── "tensordict_module" +# │ └── "actors.py" +# │ └── "common.py" +# │ └── "exploration.py" +# │ └── "probabilistic.py" +# │ └── "rnn.py" +# │ └── "sequence.py" +# │ └── "world_models.py" +# ├── "objectives" +# │ └── "a2c.py" +# │ └── "common.py" +# │ └── "cql.py" +# │ └── "ddpg.py" +# │ └── "decision_transformer.py" +# │ └── "deprecated.py" +# │ └── "dqn.py" +# │ └── "dreamer.py" +# │ └── "functional.py" +# │ └── "iql.py" +# │ ├── "multiagent" +# │ │ └── "qmixer.py" +# │ └── "ppo.py" +# │ └── "redq.py" +# │ └── "reinforce.py" +# │ └── "sac.py" +# │ └── "td3.py" +# │ ├── "value" +# │ └── "advantages.py" +# │ └── "functional.py" +# │ └── "pg.py" +# ├── "record" +# │ ├── "loggers" +# │ │ └── "common.py" +# │ │ └── "csv.py" +# │ │ └── "mlflow.py" +# │ │ └── "tensorboard.py" +# │ │ └── "wandb.py" +# │ └── "recorder.py" +# ├── "trainers" +# │ │ +# │ ├── "helpers" +# │ │ └── "collectors.py" +# │ │ └── "envs.py" +# │ │ └── "logger.py" +# │ │ └── "losses.py" +# │ │ └── "models.py" +# │ │ └── "replay_buffer.py" +# │ │ └── "trainers.py" +# │ └── "trainers.py" +# └── "version.py" # # Unlike other domains, RL is less about media than *algorithms*. As such, it # is harder to make truly independent components.