From e5a358b596ca2653455d36a4ebca95401e717054 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 17:23:37 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/dqn/dqn_atari.py | 9 +++++---- sota-implementations/dqn/dqn_cartpole.py | 9 ++++++--- sota-implementations/dqn/utils_atari.py | 2 +- sota-implementations/dqn/utils_cartpole.py | 11 ++++++----- torchrl/data/tensor_specs.py | 12 ++++++++++-- torchrl/modules/tensordict_module/exploration.py | 3 ++- 6 files changed, 30 insertions(+), 16 deletions(-) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 0e07462220f..255b6b2ee65 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -156,7 +156,9 @@ def update(sampled_tensordict): storing_device=device, max_frames_per_traj=-1, init_random_frames=init_random_frames, - compile_policy={"mode": compile_mode} if compile_mode is not None else False, + compile_policy={"mode": compile_mode, "fullgraph": True} + if compile_mode is not None + else False, cudagraph_policy=cfg.compile.cudagraphs, ) @@ -212,9 +214,8 @@ def update(sampled_tensordict): # Get and log q-values, loss, epsilon, sampling time and training time log_info.update( { - "train/q_values": (data["action_value"] * data["action"]).sum().item() - / frames_per_batch, - "train/q_loss": q_losses.mean().item(), + "train/q_values": data["chosen_action_value"].sum() / frames_per_batch, + "train/q_loss": q_losses.mean(), "train/epsilon": greedy_module.eps, } ) diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index e51a538d882..89a1e04d586 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -36,18 +36,19 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(device) # Make the components - model = make_dqn_model(cfg.env.env_name) + model = make_dqn_model(cfg.env.env_name, device=device) greedy_module = EGreedyModule( annealing_num_steps=cfg.collector.annealing_frames, eps_init=cfg.collector.eps_start, eps_end=cfg.collector.eps_end, spec=model.spec, + device=device, ) model_explore = TensorDictSequential( model, greedy_module, - ).to(device) + ) # Create the replay buffer replay_buffer = TensorDictReplayBuffer( @@ -135,7 +136,9 @@ def update(sampled_tensordict): storing_device="cpu", max_frames_per_traj=-1, init_random_frames=cfg.collector.init_random_frames, - compile_policy={"mode": compile_mode} if compile_mode is not None else False, + compile_policy={"mode": compile_mode, "fullgraph": True} + if compile_mode is not None + else False, cudagraph_policy=cfg.compile.cudagraphs, ) diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 9b10b63ef6d..0956dfeb2ac 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -88,7 +88,7 @@ def make_dqn_modules_pixels(proof_environment, device): ) qvalue_module = QValueActor( module=torch.nn.Sequential(cnn, mlp), - spec=Composite(action=action_spec), + spec=Composite(action=action_spec).to(device), in_keys=["pixels"], ) return qvalue_module diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index d378f1ec76b..c49ff15f5fc 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -31,7 +31,7 @@ def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False): # -------------------------------------------------------------------- -def make_dqn_modules(proof_environment): +def make_dqn_modules(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -45,19 +45,20 @@ def make_dqn_modules(proof_environment): activation_class=torch.nn.ReLU, out_features=num_outputs, num_cells=[120, 84], + device=device, ) qvalue_module = QValueActor( module=mlp, - spec=Composite(action=action_spec), + spec=Composite(action=action_spec).to(device), in_keys=["observation"], ) return qvalue_module -def make_dqn_model(env_name): - proof_environment = make_env(env_name, device="cpu") - qvalue_module = make_dqn_modules(proof_environment) +def make_dqn_model(env_name, device): + proof_environment = make_env(env_name, device=device) + qvalue_module = make_dqn_modules(proof_environment, device=device) del proof_environment return qvalue_module diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4563fd3ca21..1898e679717 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -462,6 +462,11 @@ class CategoricalBox(Box): n: int register = invertible_dict() + def __post_init__(self): + # n could be a numpy array or a tensor, making compile go a bit crazy + # We want to make sure we're working with a regular integer + self.__dict__["n"] = int(self.n) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox: return deepcopy(self) @@ -3313,7 +3318,10 @@ def __init__( ) self.update_mask(mask) self._provisional_n = None - self._undefined_n = self.space.n < 0 + + @torch.compiler.assume_constant_result + def _undefined_n(self): + return self.space.n == -1 def enumerate(self) -> torch.Tensor: dtype = self.dtype @@ -3380,7 +3388,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self._undefined_n: + if self._undefined_n(): if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). " diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 05c67b40c3f..6e8296a677a 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -153,7 +153,8 @@ def step(self, frames: int = 1) -> None: ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + expl = exploration_type() + if expl in (ExplorationType.RANDOM, None): if isinstance(self.action_key, tuple) and len(self.action_key) > 1: action_tensordict = tensordict.get(self.action_key[:-1]) action_key = self.action_key[-1]