From f614536320df2eabf76a12b377a02e1ac97bf91d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sat, 14 Dec 2024 15:18:38 -0800
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 sota-implementations/dqn/dqn_atari.py        |  5 +++--
 sota-implementations/dqn/utils_atari.py      | 12 +++++++-----
 .../modules/tensordict_module/exploration.py | 17 +++++++++++++----
 3 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index d8e5047e8b3..0e07462220f 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -48,17 +48,18 @@ def main(cfg: "DictConfig"):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip)
+    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
     greedy_module = EGreedyModule(
         annealing_num_steps=cfg.collector.annealing_frames,
         eps_init=cfg.collector.eps_start,
         eps_end=cfg.collector.eps_end,
         spec=model.spec,
+        device=device,
     )
     model_explore = TensorDictSequential(
         model,
         greedy_module,
-    ).to(device)
+    )
 
     # Create the replay buffer
     if cfg.buffer.scratch_dir is None:
diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py
index 1e5440a54b6..a135b78803d 100644
--- a/sota-implementations/dqn/utils_atari.py
+++ b/sota-implementations/dqn/utils_atari.py
@@ -61,7 +61,7 @@ def make_env(env_name, frame_skip, device, is_test=False):
 # --------------------------------------------------------------------
 
 
-def make_dqn_modules_pixels(proof_environment):
+def make_dqn_modules_pixels(proof_environment, device):
 
     # Define input shape
     input_shape = proof_environment.observation_spec["pixels"].shape
@@ -75,13 +75,15 @@ def make_dqn_modules_pixels(proof_environment):
         num_cells=[32, 64, 64],
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
+        device=device,
     )
-    cnn_output = cnn(torch.ones(input_shape))
+    cnn_output = cnn(torch.ones(input_shape, device=device))
     mlp = MLP(
         in_features=cnn_output.shape[-1],
         activation_class=torch.nn.ReLU,
         out_features=num_actions,
         num_cells=[512],
+        device=device,
     )
     qvalue_module = QValueActor(
         module=torch.nn.Sequential(cnn, mlp),
@@ -91,9 +93,9 @@ def make_dqn_modules_pixels(proof_environment):
     return qvalue_module
 
 
-def make_dqn_model(env_name, frame_skip):
-    proof_environment = make_env(env_name, frame_skip, device="cpu")
-    qvalue_module = make_dqn_modules_pixels(proof_environment)
+def make_dqn_model(env_name, frame_skip, device):
+    proof_environment = make_env(env_name, frame_skip, device=device)
+    qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
     del proof_environment
     return qvalue_module
 
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index a1879519271..05c67b40c3f 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -55,6 +55,7 @@ class EGreedyModule(TensorDictModuleBase):
             Default is ``"action"``.
         action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict.
             Default is ``None`` (corresponding to no mask).
+        device (torch.device, optional): the device of the exploration module.
 
     .. note::
         It is crucial to incorporate a call to :meth:`~.step` in the training loop
@@ -97,6 +98,7 @@ def __init__(
         *,
         action_key: Optional[NestedKey] = "action",
         action_mask_key: Optional[NestedKey] = None,
+        device: torch.device | None = None,
     ):
         if not isinstance(eps_init, float):
             warnings.warn("eps_init should be a float.")
@@ -112,14 +114,18 @@ def __init__(
 
         super().__init__()
 
-        self.register_buffer("eps_init", torch.as_tensor(eps_init))
-        self.register_buffer("eps_end", torch.as_tensor(eps_end))
+        self.register_buffer("eps_init", torch.as_tensor(eps_init, device=device))
+        self.register_buffer("eps_end", torch.as_tensor(eps_end, device=device))
         self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32))
+        self.register_buffer(
+            "eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device)
+        )
 
         if spec is not None:
             if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
                 spec = Composite({action_key: spec}, shape=spec.shape[:-1])
+            if device is not None:
+                spec = spec.to(device)
         self._spec = spec
 
     @property
@@ -183,7 +189,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                         f"Action mask key {self.action_mask_key} not found in {tensordict}."
                     )
                 spec.update_mask(action_mask)
-            out = torch.where(cond, spec.rand().to(out.device), out)
+            r = spec.rand()
+            if r.device != out.device:
+                r = r.to(out.device)
+            out = torch.where(cond, r, out)
         else:
             raise RuntimeError("spec must be provided to the exploration wrapper.")
         action_tensordict.set(action_key, out)
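Usage note (illustration only, not part of the patch): a minimal sketch of the new
device argument on EGreedyModule. It assumes a toy one-hot action spec from
torchrl.data.OneHot as a stand-in for the Atari model's model.spec; with device set
at construction time, the exploration buffers and spec live on the target device and
the trailing .to(device) call on the policy/exploration sequence is no longer needed.

    # Illustration only; OneHot(4) is an assumed stand-in for the real action spec.
    import torch

    from torchrl.data import OneHot
    from torchrl.modules import EGreedyModule

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Buffers ("eps", "eps_init", "eps_end") and the spec are created directly
    # on `device`, so no .to(device) is needed on the exploration module.
    greedy_module = EGreedyModule(
        spec=OneHot(4),
        eps_init=1.0,
        eps_end=0.01,
        annealing_num_steps=250_000,  # illustrative value
        device=device,
    )

    # In the training loop, anneal epsilon after each collected batch:
    greedy_module.step()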