From 3d7d9eee59d95d711087c9ace77d2d86fe6b92c4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 9 Feb 2024 15:30:28 +0000 Subject: [PATCH] amend --- torchrl/objectives/utils.py | 3 +- tutorials/sphinx-tutorials/coding_ddpg.py | 2 + tutorials/sphinx-tutorials/coding_dqn.py | 2 + tutorials/sphinx-tutorials/coding_ppo.py | 2 + .../sphinx-tutorials/getting-started-1.py | 4 +- .../sphinx-tutorials/getting-started-2.py | 79 +++++++++++++++++-- .../sphinx-tutorials/getting-started-5.py | 10 ++- 7 files changed, 90 insertions(+), 12 deletions(-) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 43dfa65c0c4..b234af6a804 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -300,8 +300,7 @@ def __init__( ): if eps is None and tau is None: raise RuntimeError( - "Neither eps nor tau was provided. " "This behaviour is deprecated.", - category=DeprecationWarning, + "Neither eps nor tau was provided. This behaviour is deprecated.", ) eps = 0.999 if (eps is None) ^ (tau is None): diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 5f8bf2c0830..252b4fd2146 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -4,6 +4,8 @@ ====================================== **Author**: `Vincent Moens `_ +.. _coding_ddpg: + """ ############################################################################## diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 70e7e29fe59..eb476dfcc15 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -4,6 +4,8 @@ ============================== **Author**: `Vincent Moens `_ +.. _coding_dqn: + """ ############################################################################## diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index be82bbd3bd8..6f31a0aed1a 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -4,6 +4,8 @@ ================================================== **Author**: `Vincent Moens `_ +.. _coding_ppo: + This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium control library `__. diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 4be8f963f3f..1a133b8c2ef 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -198,7 +198,9 @@ # to :meth:`~torchrl.modules.EGreedyModule.step` is required (see the last # :ref:`tutorial ` for an example). 
#
-exploration_module = EGreedyModule(spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5)
+exploration_module = EGreedyModule(
+    spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5
+)
 ###################################
 # To build our explorative policy, we only had to concatenate the
diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py
index 8b2471e0a33..a75ac8ed2b9 100644
--- a/tutorials/sphinx-tutorials/getting-started-2.py
+++ b/tutorials/sphinx-tutorials/getting-started-2.py
@@ -48,15 +48,18 @@
 env = GymEnv("Pendulum-v1")
-from torchrl.objectives import DDPGLoss
 from torchrl.modules import Actor, MLP, ValueOperator
+from torchrl.objectives import DDPGLoss
-n_obs = env.observation_spec['observation'].shape[-1]
+n_obs = env.observation_spec["observation"].shape[-1]
 n_act = env.action_spec.shape[-1]
 actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
-value_net = ValueOperator(MLP(in_features=n_obs+n_act, out_features=1, num_cells=[32, 32]), in_keys=["observation", "action"])
+value_net = ValueOperator(
+    MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
+    in_keys=["observation", "action"],
+)
-loss = DDPGLoss(actor_network=actor, value_network=value_net)
+ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)
 ###################################
 # And that is it! Our loss module can now be run with data coming from the
 #
 rollout = env.rollout(max_steps=100, policy=actor)
-loss_vals = loss(rollout)
+loss_vals = ddpg_loss(rollout)
 print(loss_vals)
 ###################################
+# LossModule's output
+# -------------------
+#
 # As you can see, the value we received from the loss isn't a single scalar
 # but a dictionary containing multiple losses.
 #
@@ -76,4 +82,65 @@
 # and since some users may wish to separate the optimization of each module
 # in distinct steps, TorchRL's objectives will return dictionaries containing
 # the various loss components.
-#
\ No newline at end of file
+#
+# This format also allows us to pass metadata along with the loss values. In
+# general, we make sure that only the loss values are differentiable such that
+# you can simply sum over the values of the dictionary to obtain the total
+# loss. If you want to make sure you're fully in control of what is happening,
+# you can sum over only the entries whose keys start with the ``"loss_"`` prefix:
+#
+total_loss = 0
+for key, val in loss_vals.items():
+    if key.startswith("loss_"):
+        total_loss += val
+
+###################################
+# Given all this, training the modules is not so different from what would be
+# done in any other training loop. We'll need an optimizer (or one optimizer
+# per module if that is your choice). The following items will typically be
+# found in your training loop:
+
+from torch.optim import Adam
+
+optim = Adam(ddpg_loss.parameters())
+total_loss.backward()
+optim.step()
+optim.zero_grad()
+
+###################################
+# Further considerations: Target parameters
+# -----------------------------------------
+#
+# Another important consideration is that off-policy algorithms such as DDPG
+# typically have target parameters associated with them. Target parameters are
+# usually a version of the parameters that lags in time (or a smoothed
+# average of them) and they are used for value estimation when training the
+# policy. Training the policy using target parameters is usually much more
+# efficient than using the current configuration of the value network
+# parameters. You usually don't need to care too much about target parameters
+# as your loss module will create them for you, **but** it is your
+# responsibility to update them when needed.
+# TorchRL provides a couple of updaters, namely
+# :class:`~torchrl.objectives.HardUpdate` and
+# :class:`~torchrl.objectives.SoftUpdate`. Instantiating them is very easy and
+# doesn't require any knowledge about the inner machinery of the loss module.
+#
+from torchrl.objectives import SoftUpdate
+
+updater = SoftUpdate(ddpg_loss, eps=0.99)
+
+###################################
+# In your training loop, you will need to update the target parameters at each
+# optimization step or each collection step:
+
+updater.step()
+
+###################################
+# This is all you need to know about loss modules to get started!
+#
+# To further explore the topic, have a look at:
+#
+# - The :ref:`loss module reference page `;
+# - The :ref:`Coding a DDPG loss tutorial `;
+# - Losses in action in :ref:`PPO ` or :ref:`DQN `.
+#
diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py
index 3ea2637b469..329ce9585f1 100644
--- a/tutorials/sphinx-tutorials/getting-started-5.py
+++ b/tutorials/sphinx-tutorials/getting-started-5.py
@@ -93,11 +93,14 @@
 # Loss module and optimizer
 # -------------------------
 #
+# We build our loss as indicated in the :ref:`dedicated tutorial `, with
+# its optimizer and target parameter updater:
-from torchrl.objectives import DQNLoss
+from torchrl.objectives import DQNLoss, SoftUpdate
-loss = DQNLoss(value_network=policy, action_space=env.action_spec)
+loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
 optim = Adam(loss.parameters(), lr=0.02)
+updater = SoftUpdate(loss, eps=0.99)
 #################################
 # Logger
@@ -109,7 +112,7 @@
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import CSVLogger, VideoRecorder
-path = pathlib.Path(__file__).parent / "training_loop"
+path = "./training_loop"
 logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
 video_recorder = VideoRecorder(logger, tag="video")
 record_env = TransformedEnv(
@@ -144,6 +147,7 @@
 torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
 total_count += data.numel()
 total_episodes += data["next", "done"].sum()
+updater.step()
 if max_length > 200:
 break
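
Below is a standalone sketch, not part of the patch above, showing how the snippets added to getting-started-2.py combine into a tiny end-to-end loop. It reuses the names from the tutorial (env, actor, value_net, ddpg_loss, optim, updater); wrapping the rollout and optimization in a for loop, and the iteration count, are illustrative assumptions rather than something the tutorial prescribes.

# Sketch only: combines the getting-started-2.py snippets from this patch.
# The loop length and network sizes are illustrative, not prescriptive.
from torch.optim import Adam

from torchrl.envs import GymEnv
from torchrl.modules import MLP, Actor, ValueOperator
from torchrl.objectives import DDPGLoss, SoftUpdate

env = GymEnv("Pendulum-v1")
n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]

actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
value_net = ValueOperator(
    MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
    in_keys=["observation", "action"],
)

ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)
optim = Adam(ddpg_loss.parameters())
updater = SoftUpdate(ddpg_loss, eps=0.99)

for _ in range(10):  # a handful of iterations, for illustration only
    rollout = env.rollout(max_steps=100, policy=actor)
    loss_vals = ddpg_loss(rollout)
    # only the entries whose keys start with "loss_" are differentiable loss terms
    total_loss = 0
    for key, val in loss_vals.items():
        if key.startswith("loss_"):
            total_loss += val
    total_loss.backward()
    optim.step()
    optim.zero_grad()
    updater.step()  # refresh the target parameters after each optimization step

In a real run you would feed the loss from a collector and a replay buffer rather than raw on-policy rollouts, as getting-started-5.py does.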