diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 5bdf22ea6fa..ba2db81b489 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -21,7 +21,7 @@ collector: replay_buffer: prb: 0 # use prioritized experience replay size: 1000000 - scratch_dir: null + scratch_dir: # optim optim: @@ -52,3 +52,8 @@ logger: mode: online eval_iter: 25000 video: False + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 70333f56cd9..36183655fe0 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -12,14 +12,16 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import compile_with_warmup, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -36,6 +38,9 @@ ) +torch.set_float32_matmul_precision("high") + + @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = cfg.network.device @@ -75,8 +80,19 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create TD3 loss loss_module, target_net_updater = make_loss_module(cfg, model) + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy) + collector = make_collector( + cfg, train_env, exploration_policy, compile_mode=compile_mode + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -84,14 +100,47 @@ def main(cfg: "DictConfig"): # noqa: F821 prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, scratch_dir=cfg.replay_buffer.scratch_dir, - device="cpu", + device=device, ) # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) + def update(sampled_tensordict, update_actor): + # Compute loss + q_loss, *_ = loss_module.value_loss(sampled_tensordict) + + # Update critic + q_loss.backward() + optimizer_critic.step() + optimizer_critic.zero_grad(set_to_none=True) + + # Update actor + if update_actor: + actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) + + actor_loss.backward() + optimizer_actor.step() + optimizer_actor.zero_grad(set_to_none=True) + + # Update target params + target_net_updater.step() + else: + actor_loss = q_loss.new_zeros(()) + + return q_loss.detach(), actor_loss.detach() + + if cfg.compile.compile: + update = compile_with_warmup(update, mode=compile_mode, warmup=1) + + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -108,70 +157,56 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch update_counter = 0 - sampling_start = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start - exploration_policy[1].step(tensordict.numel()) + collector_iter = iter(collector) + total_iter = len(collector) + + for _ in range(total_iter): + timeit.printevery(num_prints=1000, total_count=total_iter, erase=True) + + with timeit("collect"): + tensordict = next(collector_iter) # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) - - tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + pbar.update(current_frames) + + with timeit("rb - extend"): + # Add to replay buffer + tensordict = tensordict.reshape(-1) + replay_buffer.extend(tensordict) + collected_frames += current_frames - # Optimization steps - training_start = time.time() - if collected_frames >= init_random_frames: - ( - actor_losses, - q_losses, - ) = ([], []) - for _ in range(num_updates): - - # Update actor every delayed_updates - update_counter += 1 - update_actor = update_counter % delayed_updates == 0 - - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - q_loss, *_ = loss_module.value_loss(sampled_tensordict) - - # Update critic - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - q_losses.append(q_loss.item()) - - # Update actor - if update_actor: - actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - actor_losses.append(actor_loss.item()) - - # Update target params - target_net_updater.step() - - # Update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - - training_time = time.time() - training_start + with timeit("train"): + # Optimization steps + if collected_frames >= init_random_frames: + ( + actor_losses, + q_losses, + ) = ([], []) + for _ in range(num_updates): + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + + with timeit("rb - sample"): + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + q_loss, actor_loss = update(sampled_tensordict, update_actor) + + q_losses.append(q_loss.clone()) + if update_actor: + actor_losses.append(actor_loss.clone()) + + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -183,22 +218,21 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log = {} if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][episode_end] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + metrics_to_log["train/reward"] = episode_rewards.mean() + metrics_to_log["train/episode_length"] = episode_length.sum() / len( episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) + metrics_to_log["train/q_loss"] = torch.stack(q_losses).mean() if update_actor: - metrics_to_log["train/a_loss"] = np.mean(actor_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + metrics_to_log["train/a_loss"] = torch.stack(actor_losses).mean() # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, @@ -206,22 +240,17 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index df81a522b3c..7d27f34dea7 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -9,12 +9,12 @@ from contextlib import nullcontext import torch -from tensordict.nn import TensorDictSequential +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn, optim from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage from torchrl.envs import ( CatTensors, Compose, @@ -29,14 +29,7 @@ ) from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import ( - AdditiveGaussianModule, - MLP, - SafeModule, - SafeSequential, - TanhModule, - ValueOperator, -) +from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator from torchrl.objectives import SoftUpdate from torchrl.objectives.td3 import TD3Loss @@ -116,7 +109,7 @@ def make_environment(cfg, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector(cfg, train_env, actor_model_explore, compile_mode): """Make collector.""" device = cfg.collector.device if device in ("", None): @@ -132,48 +125,52 @@ def make_collector(cfg, train_env, actor_model_explore): total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, device=device, + compile_policy={"mode": compile_mode} if compile_mode else False, + cudagraph_policy=cfg.compile.cudagraphs, ) collector.set_seed(cfg.env.seed) return collector def make_replay_buffer( - batch_size, - prb=False, - buffer_size=1000000, - scratch_dir=None, - device="cpu", - prefetch=3, + batch_size: int, + prb: bool = False, + buffer_size: int = 1000000, + scratch_dir: str | None = None, + device: torch.device = "cpu", + prefetch: int = 3, ): with ( tempfile.TemporaryDirectory() if scratch_dir is None else nullcontext(scratch_dir) ) as scratch_dir: + storage_cls = ( + functools.partial(LazyTensorStorage, device=device) + if not scratch_dir + else functools.partial( + LazyMemmapStorage, device="cpu", scratch_dir=scratch_dir + ) + ) + if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( alpha=0.7, beta=0.5, pin_memory=False, prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), + storage=storage_cls(buffer_size), batch_size=batch_size, ) else: replay_buffer = TensorDictReplayBuffer( pin_memory=False, prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), + storage=storage_cls(buffer_size), batch_size=batch_size, ) + if scratch_dir: + replay_buffer.append_transform(lambda td: td.to(device)) return replay_buffer @@ -186,26 +183,21 @@ def make_td3_agent(cfg, train_env, eval_env, device): """Make TD3 agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] - actor_net_kwargs = { - "num_cells": cfg.network.hidden_sizes, - "out_features": action_spec.shape[-1], - "activation_class": get_activation(cfg), - } - - actor_net = MLP(**actor_net_kwargs) + action_spec = train_env.action_spec_unbatched.to(device) + actor_net = MLP( + num_cells=cfg.network.hidden_sizes, + out_features=action_spec.shape[-1], + activation_class=get_activation(cfg), + device=device, + ) in_keys_actor = in_keys - actor_module = SafeModule( + actor_module = TensorDictModule( actor_net, in_keys=in_keys_actor, - out_keys=[ - "param", - ], + out_keys=["param"], ) - actor = SafeSequential( + actor = TensorDictSequential( actor_module, TanhModule( in_keys=["param"], @@ -215,14 +207,11 @@ def make_td3_agent(cfg, train_env, eval_env, device): ) # Define Critic Network - qvalue_net_kwargs = { - "num_cells": cfg.network.hidden_sizes, - "out_features": 1, - "activation_class": get_activation(cfg), - } - qvalue_net = MLP( - **qvalue_net_kwargs, + num_cells=cfg.network.hidden_sizes, + out_features=1, + activation_class=get_activation(cfg), + device=device, ) qvalue = ValueOperator( @@ -230,17 +219,14 @@ def make_td3_agent(cfg, train_env, eval_env, device): module=qvalue_net, ) - model = nn.ModuleList([actor, qvalue]).to(device) + model = nn.ModuleList([actor, qvalue]) # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = eval_env.fake_tensordict() td = td.to(device) for net in model: net(td) - del td - eval_env.close() - # Exploration wrappers: actor_model_explore = TensorDictSequential( model[0], diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ad29b63db04..9b1275a7dde 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2274,10 +2274,13 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: return r def _project(self, val: torch.Tensor) -> torch.Tensor: - low = self.space.low.to(val.device) - high = self.space.high.to(val.device) + low = self.space.low + high = self.space.high + if self.device != val.device: + low = low.to(val.device) + high = high.to(val.device) try: - val = val.clamp_(low.item(), high.item()) + val = torch.maximum(torch.minimum(val, high), low) except ValueError: low = low.expand_as(val) high = high.expand_as(val)