diff --git a/sota-implementations/ddpg/config.yaml b/sota-implementations/ddpg/config.yaml index 43cb5093c09..ec90e59787f 100644 --- a/sota-implementations/ddpg/config.yaml +++ b/sota-implementations/ddpg/config.yaml @@ -13,7 +13,7 @@ collector: frames_per_batch: 1000 init_env_steps: 1000 reset_at_each_iter: False - device: cpu + device: env_per_collector: 1 @@ -39,6 +39,9 @@ network: hidden_sizes: [256, 256] activation: relu noise_type: "ou" # ou or gaussian + compile: False + compile_mode: + cudagraphs: False # logging logger: diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index cebc3685625..d18a547fcd6 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ -import time +import warnings import hydra @@ -18,9 +18,13 @@ import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( dump_video, @@ -44,6 +48,14 @@ def main(cfg: "DictConfig"): # noqa: F821 device = "cpu" device = torch.device(device) + collector_device = cfg.collector.device + if collector_device in ("", None): + if torch.cuda.is_available(): + collector_device = "cuda:0" + else: + collector_device = "cpu" + collector_device = torch.device(collector_device) + # Create logger exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) logger = None @@ -73,8 +85,25 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create DDPG loss loss_module, target_net_updater = make_loss_module(cfg, model) + compile_mode = None + if cfg.network.compile: + if cfg.network.compile_mode not in (None, ""): + compile_mode = cfg.network.compile_mode + elif cfg.network.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy) + collector = make_collector( + cfg, + train_env, + exploration_policy, + compile=cfg.network.compile, + compile_mode=compile_mode, + cudagraph=cfg.network.cudagraphs, + device=collector_device, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -87,9 +116,29 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) + optimizer = group_optimizers(optimizer_actor, optimizer_critic) + + def update(sampled_tensordict): + optimizer.zero_grad(set_to_none=True) + + td_loss: TensorDict = loss_module(sampled_tensordict) + td_loss.sum(reduce=True).backward() + optimizer.step() + + # Update qnet_target params + target_net_updater.step() + return td_loss.detach() + + if cfg.network.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.network.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -104,63 +153,43 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.env.max_episode_steps - sampling_start = time.time() - for _, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + tensordict = next(c_iter) # Update exploration policy exploration_policy[1].step(tensordict.numel()) # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) - - tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + pbar.update(current_frames) + # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + tensordict = tensordict.reshape(-1) + replay_buffer.extend(tensordict) + collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - actor_losses, - q_losses, - ) = ([], []) + tds = [] for _ in range(num_updates): # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Update critic - q_loss, *_ = loss_module.loss_value(sampled_tensordict) - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - - # Update actor - actor_loss, *_ = loss_module.loss_actor(sampled_tensordict) - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - - # Update qnet_target params - target_net_updater.step() + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample().to(device) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + td_loss = update(sampled_tensordict) + tds.append(td_loss.clone()) # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) + tds = torch.stack(tds) - training_time = time.time() - training_start episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -178,15 +207,14 @@ def main(cfg: "DictConfig"): # noqa: F821 ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) - metrics_to_log["train/a_loss"] = np.mean(actor_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + tds = TensorDict(train=tds).flatten_keys("/").mean() + metrics_to_log.update(tds.to_dict()) # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, @@ -194,22 +222,21 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + if i % 20 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 338081a7e8d..b94dc64ecfb 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -6,7 +6,7 @@ import torch -from tensordict.nn import TensorDictSequential +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn, optim from torchrl.collectors import SyncDataCollector @@ -30,8 +30,6 @@ AdditiveGaussianModule, MLP, OrnsteinUhlenbeckProcessModule, - SafeModule, - SafeSequential, TanhModule, ValueOperator, ) @@ -113,7 +111,15 @@ def make_environment(cfg, logger): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector( + cfg, + train_env, + actor_model_explore, + compile=False, + compile_mode=None, + cudagraph=False, + device: torch.device|None=None, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -122,7 +128,9 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, + compile_policy={"mode": compile_mode, "fullgraph": True} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector @@ -172,9 +180,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): """Make DDPG agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] + action_spec = train_env.single_action_spec actor_net_kwargs = { "num_cells": cfg.network.hidden_sizes, "out_features": action_spec.shape[-1], @@ -184,19 +190,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): actor_net = MLP(**actor_net_kwargs) in_keys_actor = in_keys - actor_module = SafeModule( + actor_module = TensorDictModule( actor_net, in_keys=in_keys_actor, - out_keys=[ - "param", - ], + out_keys=["param"], ) - actor = SafeSequential( + actor = TensorDictSequential( actor_module, TanhModule( in_keys=["param"], out_keys=["action"], - spec=action_spec, ), ) @@ -234,6 +237,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): OrnsteinUhlenbeckProcessModule( spec=action_spec, annealing_num_steps=1_000_000, + safe=False, ).to(device), ) elif cfg.network.noise_type == "gaussian": @@ -245,6 +249,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): sigma_init=1.0, mean=0.0, std=0.1, + safe=False, ).to(device), ) else: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 20128e4f6a2..345335cf9d2 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -66,6 +66,12 @@ RandomPolicy, set_exploration_type, ) +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + def cudagraph_mark_step_begin(): + """Placeholder when cudagraph_mark_step_begin is missing.""" + ... _TIMEOUT = 1.0 INSTANTIATE_TIMEOUT = 20 @@ -1145,7 +1151,11 @@ def rollout(self) -> TensorDictBase: else: policy_input = self._shuttle # we still do the assignment for security + if self.cudagraphed_policy: + cudagraph_mark_step_begin() policy_output = self.policy(policy_input) + if self.cudagraphed_policy: + policy_output = policy_output.clone() if self._shuttle is not policy_output: # ad-hoc update shuttle self._shuttle.update( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3590d76d2ce..3329158af05 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1392,6 +1392,7 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None: spec.type_check(val) def is_in(self, value) -> bool: + raise RuntimeError if self.dim == 0 and not hasattr(value, "unbind"): # We don't use unbind because value could be a tuple or a nested tensor return all( @@ -1796,6 +1797,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError if self.mask is None: shape = torch.broadcast_shapes(self._safe_shape, val.shape) shape_match = val.shape == shape @@ -2246,6 +2248,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError val_shape = _remove_neg_shapes(tensordict.utils._shape(val)) shape = torch.broadcast_shapes(self._safe_shape, val_shape) shape = list(shape) @@ -2443,6 +2446,7 @@ def one(self, shape=None): ) def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError shape = torch.broadcast_shapes(self._safe_shape, val.shape) return ( isinstance(val, NonTensorData) @@ -2635,6 +2639,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError shape = torch.broadcast_shapes(self._safe_shape, val.shape) return val.shape == shape and val.dtype == self.dtype @@ -2983,6 +2988,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten return torch.cat(out, -1) def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError vals = self._split(val) if vals is None: return False @@ -3328,6 +3334,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError if self.mask is None: shape = torch.broadcast_shapes(self._safe_shape, val.shape) shape_match = val.shape == shape @@ -3953,6 +3960,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val.squeeze(0) if val_is_scalar else val def is_in(self, val: torch.Tensor) -> bool: + raise RuntimeError if self.mask is not None: vals = val.unbind(-1) splits = self._split_self() diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 4018589bfa1..f722bc2bd7d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): keys = [out_key] values = [spec] else: - keys = list(spec.keys(True, True)) + # Make dynamo happy with the list creation + keys = [key for key in spec.keys(True, True)] # noqa: C416 values = [spec[key] for key in keys] for _spec, _key in zip(values, keys): if _spec is None: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index f04e0c78382..2f02d319086 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -133,11 +133,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -355,19 +358,20 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.sigma.data[0] = max( - self.sigma_end.item(), - ( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ).item(), + self.sigma.data.copy_( + torch.maximum( + self.sigma_end( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ), + ) ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: - sigma = self.sigma.item() + sigma = self.sigma noise = torch.normal( - mean=torch.ones(action.shape) * self.mean.item(), - std=torch.ones(action.shape) * self.std.item(), + mean=torch.ones(action.shape) * self.mean, + std=torch.ones(action.shape) * self.std, ).to(action.device) action = action + noise * sigma spec = self.spec @@ -413,6 +417,9 @@ class AdditiveGaussianModule(TensorDictModuleBase): its output spec will be of type Composite. One needs to know where to find the action spec. default: "action" + safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space + given the :obj:`TensorSpec.project` heuristic. + default: True .. note:: It is @@ -434,6 +441,7 @@ def __init__( std: float = 1.0, *, action_key: Optional[NestedKey] = "action", + safe: bool = True, ): if not isinstance(sigma_init, float): warnings.warn("eps_init should be a float.") @@ -458,7 +466,9 @@ def __init__( else: raise RuntimeError("spec cannot be None.") self._spec = spec - self.register_forward_hook(_forward_hook_safe_action) + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) @property def spec(self): @@ -474,19 +484,21 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.sigma.data[0] = max( - self.sigma_end.item(), - ( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ).item(), + self.sigma.data.copy_( + torch.maximum( + self.sigma_end, + ( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ), + ) ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: - sigma = self.sigma.item() + sigma = self.sigma noise = torch.normal( - mean=torch.ones(action.shape) * self.mean.item(), - std=torch.ones(action.shape) * self.std.item(), + mean=torch.ones(action.shape) * self.mean, + std=torch.ones(action.shape) * self.std, ).to(action.device) action = action + noise * sigma spec = self.spec[self.action_key] @@ -684,12 +696,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): if self.annealing_num_steps > 0: - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) else: raise ValueError( @@ -712,9 +726,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." ) - tensordict = self.ou.add_sample( - tensordict, self.eps.item(), is_init=is_init - ) + tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init) return tensordict @@ -778,6 +790,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): default: "action" is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps. default: "is_init" + safe (boolean, optional): if False, the TensorSpec can be None. If it + is set to False but the spec is passed, the projection will still + happen. + Default is True. Examples: >>> import torch @@ -820,6 +836,7 @@ def __init__( *, action_key: Optional[NestedKey] = "action", is_init_key: Optional[NestedKey] = "is_init", + safe: bool = True, ): super().__init__() @@ -863,7 +880,9 @@ def __init__( self._spec.update(ou_specs) if len(set(self.out_keys)) != len(self.out_keys): raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}") - self.register_forward_hook(_forward_hook_safe_action) + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) @property def spec(self): @@ -878,12 +897,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): if self.annealing_num_steps > 0: - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) else: raise ValueError( @@ -905,9 +926,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." ) - tensordict = self.ou.add_sample( - tensordict, self.eps.item(), is_init=is_init - ) + tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init) return tensordict @@ -971,11 +990,12 @@ def _make_noise_pair( tensordict.set(self.noise_key, noise) tensordict.set(self.steps_key, steps) else: - noise = tensordict.get(self.noise_key) - steps = tensordict.get(self.steps_key) + # We must clone for cudagraph, otherwise the same tensor may re-enter the compiled region + noise = tensordict.get(self.noise_key).clone() + steps = tensordict.get(self.steps_key).clone() if is_init is not None: - noise[is_init] = 0 - steps[is_init] = 0 + noise = torch.masked_fill(noise, is_init, 0) + steps = torch.masked_fill(steps, is_init, 0) return noise, steps def add_sample( @@ -1025,9 +1045,9 @@ def add_sample( * np.sqrt(self.dt) * torch.randn_like(prev_noise) ) - tensordict.set_(self.noise_key, noise - self.x0) - tensordict.set_(self.key, tensordict.get(self.key) + eps * noise) - tensordict.set_(self.steps_key, n_steps + 1) + tensordict.set(self.noise_key, noise - self.x0) + tensordict.set(self.key, tensordict.get(self.key) + eps * noise) + tensordict.set(self.steps_key, n_steps + 1) return tensordict def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: