Commit 3d7d9ee ("amend")
vmoens committed Feb 9, 2024
1 parent 18e8b10
Showing 7 changed files with 90 additions and 12 deletions.
3 changes: 1 addition & 2 deletions torchrl/objectives/utils.py
@@ -300,8 +300,7 @@ def __init__(
    ):
        if eps is None and tau is None:
            raise RuntimeError(
-               "Neither eps nor tau was provided. " "This behaviour is deprecated.",
-               category=DeprecationWarning,
+               "Neither eps nor tau was provided. This behaviour is deprecated.",
            )
            eps = 0.999
        if (eps is None) ^ (tau is None):
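For context on the fix above: ``category`` is a keyword of ``warnings.warn``, not of exception constructors, so the removed argument would have made the ``RuntimeError(...)`` call itself fail with a ``TypeError``. A minimal sketch (not part of this commit) of the warning form, in case a deprecation warning rather than a hard error is intended:

import warnings

warnings.warn(
    "Neither eps nor tau was provided. This behaviour is deprecated.",
    category=DeprecationWarning,
)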
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -4,6 +4,8 @@
======================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _coding_ddpg:
"""

##############################################################################
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -4,6 +4,8 @@
==============================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _coding_dqn:
"""

##############################################################################
2 changes: 2 additions & 0 deletions tutorials/sphinx-tutorials/coding_ppo.py
@@ -4,6 +4,8 @@
==================================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _coding_ppo:
This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy
network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.
4 changes: 3 additions & 1 deletion tutorials/sphinx-tutorials/getting-started-1.py
@@ -198,7 +198,9 @@
# to :meth:`~torchrl.modules.EGreedyModule.step` is required (see the last
# :ref:`tutorial <gs_first_training>` for an example).
#
-exploration_module = EGreedyModule(spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5)
+exploration_module = EGreedyModule(
+    spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5
+)

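A minimal sketch (not part of the diff) of the annealing call mentioned above, assuming ``EGreedyModule.step`` accepts the number of frames collected since the last call:

frames_in_batch = 10  # hypothetical count of frames just collected
exploration_module.step(frames_in_batch)  # anneal epsilon by that many steps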
###################################
# To build our explorative policy, we only had to concatenate the
79 changes: 73 additions & 6 deletions tutorials/sphinx-tutorials/getting-started-2.py
@@ -48,15 +48,18 @@

env = GymEnv("Pendulum-v1")

-from torchrl.objectives import DDPGLoss
from torchrl.modules import Actor, MLP, ValueOperator
+from torchrl.objectives import DDPGLoss

-n_obs = env.observation_spec['observation'].shape[-1]
+n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]
actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
-value_net = ValueOperator(MLP(in_features=n_obs+n_act, out_features=1, num_cells=[32, 32]), in_keys=["observation", "action"])
+value_net = ValueOperator(
+    MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
+    in_keys=["observation", "action"],
+)

-loss = DDPGLoss(actor_network=actor, value_network=value_net)
+ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)

###################################
# And that is it! Our loss module can now be run with data coming from the
@@ -65,15 +68,79 @@
#

rollout = env.rollout(max_steps=100, policy=actor)
-loss_vals = loss(rollout)
+loss_vals = ddpg_loss(rollout)
print(loss_vals)

###################################
# LossModule's output
# -------------------
#
# As you can see, the value we received from the loss isn't a single scalar
# but a dictionary containing multiple losses.
#
# The reason is simple: more than one network may be trained at a time, and
# some users may wish to separate the optimization of each module into
# distinct steps, so TorchRL's objectives return dictionaries containing the
# various loss components.
#
#
# This format also allows us to pass metadata along with the loss values. In
# general, we make sure that only the loss values are differentiable such that
# you can simply sum over the values of the dictionary to obtain the total
# loss. If you want to be fully in control of what is happening, you can sum
# over only the entries whose keys start with the ``"loss_"`` prefix:
#
total_loss = 0
for key, val in loss_vals.items():
    if key.startswith("loss_"):
        total_loss += val
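As a hedged aside (exact keys can vary across TorchRL versions), the DDPG loss dictionary typically exposes differentiable entries such as ``loss_actor`` and ``loss_value`` alongside non-differentiable metadata; the reduction above can also be written as a one-liner:

total_loss = sum(val for key, val in loss_vals.items() if key.startswith("loss_"))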

###################################
# Given all this, training the modules is not so different from what would be
# done in any other training loop. We'll need an optimizer (or one optimizer
# per module if that is your choice). The following items will typically be
# found in your training loop:

from torch.optim import Adam

optim = Adam(ddpg_loss.parameters())
total_loss.backward()
optim.step()
optim.zero_grad()

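Putting the pieces above together, a minimal illustrative sketch (not part of the commit; the helper name and hyperparameters are arbitrary) of one full optimization step:

def training_step(max_steps=100):
    # Collect a short rollout with the current actor (illustration only;
    # a real DDPG setup would use a collector, exploration and a replay buffer).
    rollout = env.rollout(max_steps=max_steps, policy=actor)
    loss_vals = ddpg_loss(rollout)
    # Keep only the differentiable "loss_" entries and take a gradient step.
    total = sum(val for key, val in loss_vals.items() if key.startswith("loss_"))
    total.backward()
    optim.step()
    optim.zero_grad()
    return loss_vals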
###################################
# Further considerations: Target parameters
# -----------------------------------------
#
# Another important consideration is that off-policy algorithms such as DDPG
# typically have target parameters associated with them. Target parameters are
# usually a time-lagged (or exponentially smoothed) copy of the parameters,
# and they are used for value estimation when training the policy. Training
# the policy against target parameters is usually much more efficient than
# using the current configuration of the value-network parameters. You
# usually don't need to care too much about target parameters, as your loss
# module creates them for you, **but** it is your responsibility to update
# them when needed.
# TorchRL provides a couple of updaters, namely
# :class:`~torchrl.objectives.HardUpdate` and
# :class:`~torchrl.objectives.SoftUpdate`. Instantiating them is very easy and
# doesn't require any knowledge about the inner machinery of the loss module.
#
from torchrl.objectives import SoftUpdate

updater = SoftUpdate(ddpg_loss, eps=0.99)

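If hard (periodic copy) updates are preferred instead, a sketch; the interval keyword name below is an assumption, so check the ``HardUpdate`` reference for the exact signature:

from torchrl.objectives import HardUpdate

# Copies the online parameters into the target parameters every N calls to step().
hard_updater = HardUpdate(ddpg_loss, value_network_update_interval=100)  # keyword name assumed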
###################################
# In your training loop, you will need to update the target parameters at each
# optimization step or each collection step:

updater.step()

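A short sketch of how the updater fits into a loop, reusing the hypothetical ``training_step`` helper sketched earlier:

for _ in range(10):  # arbitrary number of iterations
    training_step()
    updater.step()  # refresh target parameters after each optimization step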
###################################
# This is all you need to know about loss modules to get started!
#
# To further explore the topic, have a look at:
#
# - The :ref:`loss module reference page <reference/objectives>`;
# - The :ref:`Coding a DDPG loss tutorial <coding_ddpg>`;
# - Losses in action in :ref:`PPO <coding_ppo>` or :ref:`DQN <coding_dqn>`.
#
10 changes: 7 additions & 3 deletions tutorials/sphinx-tutorials/getting-started-5.py
@@ -93,11 +93,14 @@
# Loss module and optimizer
# -------------------------
#
# We build our loss as indicated in the :ref:`dedicated tutorial <gs_optim>`, with
# its optimizer and target parameter updater:

-from torchrl.objectives import DQNLoss
+from torchrl.objectives import DQNLoss, SoftUpdate

-loss = DQNLoss(value_network=policy, action_space=env.action_spec)
+loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

#################################
# Logger
@@ -109,7 +112,7 @@
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

-path = pathlib.Path(__file__).parent / "training_loop"
+path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
@@ -144,6 +147,7 @@
            torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()
+           updater.step()
    if max_length > 200:
        break

