[Doc] Fix tutos #1863

Merged: 9 commits, Feb 2, 2024

5 changes: 5 additions & 0 deletions docs/source/conf.py
@@ -189,3 +189,8 @@
generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial")
# generate_tutorial_references("../../tutorials/src/", "src")
generate_tutorial_references("../../tutorials/media/", "media")

# We do this to indicate that the script is run by sphinx
import builtins

builtins.__sphinx_build__ = True
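
For context, a tutorial script can later detect this flag with a guarded lookup, since the name only exists in builtins when the docs build injects it; a minimal sketch mirroring the pattern this PR adds to coding_dqn.py below:

# Sketch: detect whether the script is being executed by the Sphinx build.
# __sphinx_build__ is only injected into builtins by docs/source/conf.py.
try:
    is_sphinx = __sphinx_build__  # noqa: F821
except NameError:
    is_sphinx = False
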
210 changes: 108 additions & 102 deletions tutorials/sphinx-tutorials/coding_ddpg.py

Large diffs are not rendered by default.

51 changes: 38 additions & 13 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -86,6 +86,8 @@
import tempfile
import warnings

from tensordict.nn import TensorDictSequential

warnings.filterwarnings("ignore")

from torch import multiprocessing
@@ -94,13 +96,17 @@
 # `__main__` method call, but for ease of reading the code switches to fork,
 # which is also the default start method in Google's Colaboratory
 try:
-    multiprocessing.set_start_method("fork")
+    is_sphinx = __sphinx_build__
+except NameError:
+    is_sphinx = False
+
+try:
+    multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
 except RuntimeError:
-    assert multiprocessing.get_start_method() == "fork"
+    pass


# sphinx_gallery_end_ignore

import os
import uuid

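As an aside to the start-method selection above, the comment alludes to the `__main__` restriction: with the "spawn" start method, child processes re-import the main module, so any process-creating code must normally sit behind an `if __name__ == "__main__":` guard. A hypothetical, self-contained sketch (not part of this diff):

# Hypothetical sketch: guard process creation when using the "spawn" method,
# since spawned children re-import the main module on startup.
from torch import multiprocessing


def main():
    # build environments, collectors and the training loop here
    ...


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    main()
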
@@ -125,7 +131,7 @@
ToTensorImage,
TransformedEnv,
)
-from torchrl.modules import DuelingCnnDQNet, EGreedyWrapper, QValueActor
+from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor

from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
@@ -270,6 +276,7 @@ def get_norm_stats():
# let's check that normalizing constants have a size of ``[C, 1, 1]`` where
# ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
print("state dict of the observation norm:", obs_norm_sd)
test_env.close()
return obs_norm_sd


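To make the shape comment above concrete, a small hypothetical sanity check (assuming, as for :class:`~torchrl.envs.ObservationNorm`, that the returned state dict exposes ``loc`` and ``scale`` entries, and that ``torch`` and ``get_norm_stats`` from the tutorial are in scope):

# Hypothetical check on the stats returned by get_norm_stats(): with
# CatFrames stacking C=4 frames, loc/scale should broadcast as [C, 1, 1].
obs_norm_sd = get_norm_stats()
assert obs_norm_sd["loc"].shape == torch.Size([4, 1, 1])
assert obs_norm_sd["scale"].shape == torch.Size([4, 1, 1])
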
@@ -328,13 +335,14 @@ def make_model(dummy_env):
tensordict = dummy_env.fake_tensordict()
actor(tensordict)

-    # we wrap our actor in an EGreedyWrapper for data collection
-    actor_explore = EGreedyWrapper(
-        actor,
+    # we join our actor with an EGreedyModule for data collection
+    exploration_module = EGreedyModule(
         spec=dummy_env.action_spec,
         annealing_num_steps=total_frames,
         eps_init=eps_greedy_val,
         eps_end=eps_greedy_val_env,
     )
+    actor_explore = TensorDictSequential(actor, exploration_module)

return actor, actor_explore

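For readers following the EGreedyWrapper -> EGreedyModule migration, a hypothetical usage sketch of the returned pair (names are taken from the tutorial's diff; the calls are illustrative only):

# Hypothetical sketch: the TensorDictSequential policy is called like the
# plain actor, but the trailing EGreedyModule makes the action epsilon-greedy.
actor, actor_explore = make_model(dummy_env)
td = dummy_env.fake_tensordict()
actor_explore(td)  # writes an exploratory "action" entry into the tensordict
print(actor_explore[1])  # the EGreedyModule, later annealed via .step()
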
@@ -381,6 +389,13 @@ def get_replay_buffer(buffer_size, n_optim, batch_size):
# We choose the following configuration: we will be running a series of
# parallel environments synchronously in parallel in different collectors,
# themselves running in parallel but asynchronously.
#
# .. note::
#   This feature is only available when the code is run with the "spawn"
#   start method of Python's multiprocessing library. If this tutorial is run
#   directly as a script (and therefore with the "fork" method), we will use
#   a regular :class:`~torchrl.collectors.SyncDataCollector`.
#
# The advantage of this configuration is that we can balance the amount of
# compute that is executed in batch with what we want to be executed
# asynchronously. We encourage the reader to experiment how the collection
@@ -409,11 +424,10 @@ def get_collector(
total_frames,
device,
):
-    data_collector = MultiaSyncDataCollector(
-        [
-            make_env(parallel=True, obs_norm_sd=stats),
-        ]
-        * num_collectors,
+    cls = MultiaSyncDataCollector
+    env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors
+    data_collector = cls(
+        env_arg,
policy=actor_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
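
To make the note above concrete, a hypothetical sketch (not the tutorial's actual code) of choosing the collector class from the multiprocessing start method:

# Hypothetical sketch: use the asynchronous multi-collector under "spawn" and
# fall back to a single synchronous collector under "fork".
from torch import multiprocessing
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector

if multiprocessing.get_start_method() == "spawn":
    collector_cls = MultiaSyncDataCollector
    env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors
else:
    collector_cls = SyncDataCollector
    env_arg = make_env(parallel=True, obs_norm_sd=stats)
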
@@ -464,7 +478,12 @@ def get_loss_module(actor, gamma):
# in practice, and the performance of the algorithm should hopefully not be
# too sensitive to slight variations of these.

-device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
+is_fork = multiprocessing.get_start_method() == "fork"
+device = (
+    torch.device(0)
+    if torch.cuda.is_available() and not is_fork
+    else torch.device("cpu")
+)

###############################################################################
# Optimizer
@@ -642,6 +661,12 @@ def get_loss_module(actor, gamma):
)
recorder.register(trainer)

###############################################################################
# The exploration module's epsilon factor is also annealed:
#

trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)

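In other words, each time ``frames_per_batch`` new frames have been collected, epsilon is moved a little further from ``eps_init`` toward ``eps_end``; a hypothetical manual equivalent of the hook registered above:

# Hypothetical manual equivalent of the "post_steps" hook: anneal epsilon by
# the number of frames gathered after every collection/optimization cycle.
for batch in data_collector:
    # ... optimization steps on `batch` would happen here ...
    actor_explore[1].step(frames=frames_per_batch)
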
###############################################################################
# - Any callable (including :class:`~torchrl.trainers.TrainerHookBase`
# subclasses) can be registered using :meth:`~torchrl.trainers.Trainer.register_op`.
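
As an illustration of that point, a hypothetical custom hook (not part of this PR) could be attached to the same trainer:

# Hypothetical example: any callable can be hooked into the training loop.
def print_progress(*args, **kwargs):
    # runs after every post-steps phase of the trainer
    print("completed one post_steps cycle")


trainer.register_op("post_steps", print_progress)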