Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 2, 2024
1 parent 8516540 commit af83e06
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 9 deletions.
6 changes: 5 additions & 1 deletion tutorials/sphinx-tutorials/coding_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@
###############################################################################
# We will execute the policy on CUDA if available
is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA

###############################################################################
Expand Down
6 changes: 5 additions & 1 deletion tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,11 @@ def get_loss_module(actor, gamma):
# too sensitive to slight variations of these.

is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)

###############################################################################
# Optimizer
Expand Down
6 changes: 5 additions & 1 deletion tutorials/sphinx-tutorials/coding_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@
#

is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
num_cells = 256 # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0
Expand Down
6 changes: 5 additions & 1 deletion tutorials/sphinx-tutorials/dqn_with_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)

######################################################################
# Environment
Expand Down
12 changes: 8 additions & 4 deletions tutorials/sphinx-tutorials/multiagent_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@
# Torch
import torch

# Tensordict modules
from torch import multiprocessing

from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

# Tensordict modules
from torch import multiprocessing

# Data collection
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
Expand Down Expand Up @@ -164,7 +164,11 @@

# Devices
is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
vmas_device = device # The device where the simulator is run (VMAS can run on GPU)

# Sampling
Expand Down
6 changes: 5 additions & 1 deletion tutorials/sphinx-tutorials/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
from torchrl.modules import Actor

is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)

##############################################################################
# Let us first create an environment. For the sake of simplicity, we will be using
Expand Down

0 comments on commit af83e06

Please sign in to comment.