
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 14, 2024
2 parents (47c5422 + f614536), commit eae42a9
Showing 3 changed files with 23 additions and 11 deletions.
5 changes: 3 additions & 2 deletions sota-implementations/dqn/dqn_atari.py
@@ -48,17 +48,18 @@ def main(cfg: "DictConfig"): # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip)
+    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
     greedy_module = EGreedyModule(
         annealing_num_steps=cfg.collector.annealing_frames,
         eps_init=cfg.collector.eps_start,
         eps_end=cfg.collector.eps_end,
         spec=model.spec,
+        device=device,
     )
     model_explore = TensorDictSequential(
         model,
         greedy_module,
-    ).to(device)
+    )

     # Create the replay buffer
     if cfg.buffer.scratch_dir is None:
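
For context, a minimal sketch of how the updated entry point wires the device through both components. The env name, frame skip, and epsilon schedule below are illustrative, and `make_dqn_model` is the helper from `utils_atari.py` in this directory:

import torch
from tensordict.nn import TensorDictSequential
from torchrl.modules import EGreedyModule

from utils_atari import make_dqn_model  # helper defined in sota-implementations/dqn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The Q-network and the exploration module are now built directly on `device`,
# so the composed policy no longer needs a trailing `.to(device)` move.
model = make_dqn_model("PongNoFrameskip-v4", frame_skip=4, device=device)
greedy_module = EGreedyModule(
    annealing_num_steps=250_000,
    eps_init=1.0,
    eps_end=0.01,
    spec=model.spec,
    device=device,
)
model_explore = TensorDictSequential(model, greedy_module)
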
12 changes: 7 additions & 5 deletions sota-implementations/dqn/utils_atari.py
@@ -61,7 +61,7 @@ def make_env(env_name, frame_skip, device, is_test=False):
 # --------------------------------------------------------------------


-def make_dqn_modules_pixels(proof_environment):
+def make_dqn_modules_pixels(proof_environment, device):

     # Define input shape
     input_shape = proof_environment.observation_spec["pixels"].shape
@@ -75,13 +75,15 @@ def make_dqn_modules_pixels(proof_environment):
         num_cells=[32, 64, 64],
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
+        device=device,
     )
-    cnn_output = cnn(torch.ones(input_shape))
+    cnn_output = cnn(torch.ones(input_shape, device=device))
     mlp = MLP(
         in_features=cnn_output.shape[-1],
         activation_class=torch.nn.ReLU,
         out_features=num_actions,
         num_cells=[512],
+        device=device,
     )
     qvalue_module = QValueActor(
         module=torch.nn.Sequential(cnn, mlp),
@@ -91,9 +93,9 @@ def make_dqn_modules_pixels(proof_environment):
     return qvalue_module


-def make_dqn_model(env_name, frame_skip):
-    proof_environment = make_env(env_name, frame_skip, device="cpu")
-    qvalue_module = make_dqn_modules_pixels(proof_environment)
+def make_dqn_model(env_name, frame_skip, device):
+    proof_environment = make_env(env_name, frame_skip, device=device)
+    qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
     del proof_environment
     return qvalue_module

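
As a standalone illustration of the pattern above, the dummy forward pass used for shape inference can be run directly on the target device. A sketch assuming torchrl's `ConvNet` and `MLP`, with an illustrative stacked-frame Atari observation shape and action count:

import torch
from torchrl.modules import MLP, ConvNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input_shape = (4, 84, 84)  # illustrative stacked-frame Atari observation
num_actions = 6            # illustrative action count

# Build the CNN on the target device and run a dummy forward pass there to
# infer the flattened feature size that feeds the MLP head.
cnn = ConvNet(
    activation_class=torch.nn.ReLU,
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
    device=device,
)
cnn_output = cnn(torch.ones(input_shape, device=device))
mlp = MLP(
    in_features=cnn_output.shape[-1],
    activation_class=torch.nn.ReLU,
    out_features=num_actions,
    num_cells=[512],
    device=device,
)
print(cnn_output.shape, next(mlp.parameters()).device)
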
17 changes: 13 additions & 4 deletions torchrl/modules/tensordict_module/exploration.py
@@ -55,6 +55,7 @@ class EGreedyModule(TensorDictModuleBase):
             Default is ``"action"``.
         action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict.
             Default is ``None`` (corresponding to no mask).
+        device (torch.device, optional): the device of the exploration module.

     .. note::
         It is crucial to incorporate a call to :meth:`~.step` in the training loop
@@ -97,6 +98,7 @@ def __init__(
         *,
         action_key: Optional[NestedKey] = "action",
         action_mask_key: Optional[NestedKey] = None,
+        device: torch.device | None = None,
     ):
         if not isinstance(eps_init, float):
             warnings.warn("eps_init should be a float.")
@@ -112,14 +114,18 @@

         super().__init__()

-        self.register_buffer("eps_init", torch.as_tensor(eps_init))
-        self.register_buffer("eps_end", torch.as_tensor(eps_end))
+        self.register_buffer("eps_init", torch.as_tensor(eps_init, device=device))
+        self.register_buffer("eps_end", torch.as_tensor(eps_end, device=device))
         self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32))
+        self.register_buffer(
+            "eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device)
+        )

         if spec is not None:
             if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
                 spec = Composite({action_key: spec}, shape=spec.shape[:-1])
+            if device is not None:
+                spec = spec.to(device)
         self._spec = spec

     @property
@@ -183,7 +189,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                             f"Action mask key {self.action_mask_key} not found in {tensordict}."
                         )
                     spec.update_mask(action_mask)
-                out = torch.where(cond, spec.rand().to(out.device), out)
+                r = spec.rand()
+                if r.device != out.device:
+                    r = r.to(out.device)
+                out = torch.where(cond, r, out)
             else:
                 raise RuntimeError("spec must be provided to the exploration wrapper.")
             action_tensordict.set(action_key, out)
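
A quick sanity check of the new `device` keyword; a minimal sketch assuming a one-hot action spec (`OneHot` from `torchrl.data`), with illustrative sizes and epsilon values:

import torch
from torchrl.data import OneHot
from torchrl.modules import EGreedyModule

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Buffers (eps_init, eps_end, eps) and the action spec are placed on `device`
# at construction time, so forward() avoids a cross-device copy when sampling
# random actions.
explore = EGreedyModule(
    spec=OneHot(4, device=device),
    eps_init=1.0,
    eps_end=0.05,
    annealing_num_steps=1_000,
    device=device,
)
print(explore.eps.device, explore.spec.device)
explore.step()  # anneal epsilon by one step, as done once per collected batch
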
