[Doc] Fix tutos (#1863)
vmoens committed Feb 6, 2024
1 parent 2d87abe commit 0af176f
Showing 12 changed files with 486 additions and 326 deletions.
5 changes: 5 additions & 0 deletions docs/source/conf.py
@@ -189,3 +189,8 @@
generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial")
# generate_tutorial_references("../../tutorials/src/", "src")
generate_tutorial_references("../../tutorials/media/", "media")

# We do this to indicate that the script is run by sphinx
import builtins

builtins.__sphinx_build__ = True
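
The tutorial scripts read this flag back with a guarded lookup, since the name only exists while Sphinx builds the docs. A minimal sketch of the consuming side (it mirrors the check added to coding_dqn.py below):

try:
    is_sphinx = __sphinx_build__  # defined by docs/source/conf.py during the docs build
except NameError:
    is_sphinx = False  # the flag is absent when the tutorial runs as a plain script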
210 changes: 108 additions & 102 deletions tutorials/sphinx-tutorials/coding_ddpg.py

Large diffs are not rendered by default.

51 changes: 38 additions & 13 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -86,6 +86,8 @@
import tempfile
import warnings

from tensordict.nn import TensorDictSequential

warnings.filterwarnings("ignore")

from torch import multiprocessing
@@ -94,13 +96,17 @@
# `__main__` method call, but for ease of reading the code switches to fork,
# which is also the default start method in Google's Colaboratory
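# (The ``__sphinx_build__`` flag checked below is set in ``docs/source/conf.py``;
# when the docs are built we use "spawn" instead.)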
try:
multiprocessing.set_start_method("fork")
is_sphinx = __sphinx_build__
except NameError:
is_sphinx = False

try:
multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"
pass


# sphinx_gallery_end_ignore

import os
import uuid

@@ -125,7 +131,7 @@
ToTensorImage,
TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyWrapper, QValueActor
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor

from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
@@ -270,6 +276,7 @@ def get_norm_stats():
# let's check that normalizing constants have a size of ``[C, 1, 1]`` where
# ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
print("state dict of the observation norm:", obs_norm_sd)
test_env.close()
return obs_norm_sd


@@ -328,13 +335,14 @@ def make_model(dummy_env):
tensordict = dummy_env.fake_tensordict()
actor(tensordict)

# we wrap our actor in an EGreedyWrapper for data collection
actor_explore = EGreedyWrapper(
actor,
# we join our actor with an EGreedyModule for data collection
exploration_module = EGreedyModule(
spec=dummy_env.action_spec,
annealing_num_steps=total_frames,
eps_init=eps_greedy_val,
eps_end=eps_greedy_val_env,
)
actor_explore = TensorDictSequential(actor, exploration_module)

return actor, actor_explore
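
# An illustrative sketch (not part of this diff) of how the composed policy is
# used: ``TensorDictSequential`` keeps its sub-modules indexable, so the
# exploration module -- and its ``step`` method that anneals epsilon -- remains
# reachable, which the trainer hook registered further below relies on:
#
#   actor, actor_explore = make_model(dummy_env)
#   td = actor_explore(dummy_env.fake_tensordict())  # actor output, then eps-greedy action
#   actor_explore[1].step(frames=1)                  # anneal epsilon by one frame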

@@ -381,6 +389,13 @@ def get_replay_buffer(buffer_size, n_optim, batch_size):
# We choose the following configuration: each collector runs a batch of
# environments in parallel (stepped synchronously), and the collectors
# themselves run in parallel but asynchronously.
#
# .. note::
#   This feature is only available when the code is run with the "spawn"
#   start method of Python's multiprocessing library. If this tutorial is run
#   directly as a script (and therefore uses the "fork" method), we will fall
#   back to a regular :class:`~torchrl.collectors.SyncDataCollector`.
#
# The advantage of this configuration is that we can balance the amount of
# compute that is executed in batch with what we want to be executed
# asynchronously. We encourage the reader to experiment how the collection
@@ -409,11 +424,10 @@ def get_collector(
total_frames,
device,
):
data_collector = MultiaSyncDataCollector(
[
make_env(parallel=True, obs_norm_sd=stats),
]
* num_collectors,
cls = MultiaSyncDataCollector
env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors
data_collector = cls(
env_arg,
policy=actor_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
@@ -464,7 +478,12 @@ def get_loss_module(actor, gamma):
# in practice, and the performance of the algorithm should hopefully not be
# too sensitive to slight variations of these.

device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)

###############################################################################
# Optimizer
@@ -642,6 +661,12 @@
)
recorder.register(trainer)

###############################################################################
# The exploration module's epsilon factor is also annealed:
#

trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)
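
# ``actor_explore[1]`` is the ``EGreedyModule`` appended after the actor in the
# ``TensorDictSequential`` above; its ``step`` method anneals epsilon by
# ``frames_per_batch`` frames each time the hook fires.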

###############################################################################
# - Any callable (including :class:`~torchrl.trainers.TrainerHookBase`
# subclasses) can be registered using :meth:`~torchrl.trainers.Trainer.register_op`.