diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml
index 5cf531a3be2..587f575be0d 100644
--- a/sota-implementations/sac/config.yaml
+++ b/sota-implementations/sac/config.yaml
@@ -51,3 +51,8 @@ logger:
   mode: online
   eval_iter: 25000
   video: False
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index ee3e7d08df0..440282eb9f2 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -12,7 +12,7 @@
 """
 from __future__ import annotations

-import time
+import warnings

 import hydra

@@ -21,8 +21,11 @@
 import torch.cuda
 import tqdm
 from tensordict import TensorDict
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import compile_with_warmup, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers

 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
@@ -75,8 +78,19 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create SAC loss
     loss_module, target_net_updater = make_loss_module(cfg, model)

+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create off-policy collector
-    collector = make_collector(cfg, train_env, exploration_policy)
+    collector = make_collector(
+        cfg, train_env, exploration_policy, compile_mode=compile_mode
+    )

     # Create replay buffer
     replay_buffer = make_replay_buffer(
@@ -84,7 +98,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         prb=cfg.replay_buffer.prb,
         buffer_size=cfg.replay_buffer.size,
         scratch_dir=cfg.replay_buffer.scratch_dir,
-        device="cpu",
+        device=device,
     )

     # Create optimizers
@@ -93,9 +107,36 @@ def main(cfg: "DictConfig"):  # noqa: F821
         optimizer_critic,
         optimizer_alpha,
     ) = make_sac_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        # Compute loss
+        loss_td = loss_module(sampled_tensordict)
+
+        actor_loss = loss_td["loss_actor"]
+        q_loss = loss_td["loss_qvalue"]
+        alpha_loss = loss_td["loss_alpha"]
+
+        (actor_loss + q_loss + alpha_loss).sum().backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return loss_td.detach()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

     # Main loop
-    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -110,69 +151,48 @@ def main(cfg: "DictConfig"):  # noqa: F821
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.env.max_episode_steps

-    sampling_start = time.time()
-    for i, tensordict in enumerate(collector):
-        sampling_time = time.time() - sampling_start
+    collector_iter = iter(collector)
+    total_iter = len(collector)
+
+    for i in range(total_iter):
+        timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
+
+        with timeit("collect"):
+            tensordict = next(collector_iter)

         # Update weights of the inference policy
         collector.update_policy_weights_()

-        pbar.update(tensordict.numel())
-
-        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
-        # Add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        pbar.update(current_frames)
+
+        with timeit("rb - extend"):
+            tensordict = tensordict.reshape(-1)
+            # Add to replay buffer
+            replay_buffer.extend(tensordict)
+
         collected_frames += current_frames

         # Optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            losses = TensorDict(batch_size=[num_updates])
-            for i in range(num_updates):
-                # Sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
+        with timeit("train"):
+            if collected_frames >= init_random_frames:
+                losses = TensorDict(batch_size=[num_updates])
+                for i in range(num_updates):
+                    with timeit("rb - sample"):
+                        # Sample from replay buffer
+                        sampled_tensordict = replay_buffer.sample()
+
+                    with timeit("update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        loss_td = update(sampled_tensordict).clone()
+                    losses[i] = loss_td.select(
+                        "loss_actor", "loss_qvalue", "loss_alpha"
                     )
-                else:
-                    sampled_tensordict = sampled_tensordict.clone()
-
-                # Compute loss
-                loss_td = loss_module(sampled_tensordict)
-
-                actor_loss = loss_td["loss_actor"]
-                q_loss = loss_td["loss_qvalue"]
-                alpha_loss = loss_td["loss_alpha"]
-
-                # Update actor
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                # Update critic
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # Update alpha
-                optimizer_alpha.zero_grad()
-                alpha_loss.backward()
-                optimizer_alpha.step()
-
-                losses[i] = loss_td.select(
-                    "loss_actor", "loss_qvalue", "loss_alpha"
-                ).detach()
-
-                # Update qnet_target params
-                target_net_updater.step()
-                # Update priority
-                if prb:
-                    replay_buffer.update_priority(sampled_tensordict)
+                    # Update priority
+                    if prb:
+                        replay_buffer.update_priority(sampled_tensordict)

-        training_time = time.time() - training_start
         episode_end = (
             tensordict["next", "done"]
             if tensordict["next", "done"].any()
@@ -184,23 +204,23 @@ def main(cfg: "DictConfig"):  # noqa: F821
         metrics_to_log = {}
         if len(episode_rewards) > 0:
             episode_length = tensordict["next", "step_count"][episode_end]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+            metrics_to_log["train/reward"] = episode_rewards
+            metrics_to_log["train/episode_length"] = episode_length.sum() / len(
                 episode_length
             )
         if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
-            metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
-            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
-            metrics_to_log["train/alpha"] = loss_td["alpha"].item()
-            metrics_to_log["train/entropy"] = loss_td["entropy"].item()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
+            losses = losses.mean()
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue")
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor")
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha")
+            metrics_to_log["train/alpha"] = loss_td["alpha"]
+            metrics_to_log["train/entropy"] = loss_td["entropy"]

         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
@@ -208,22 +228,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     break_when_any_done=True,
                 )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()

     collector.shutdown()
     if not eval_env.is_closed:
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


 if __name__ == "__main__":
diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py
index 6d37f5ec3d8..8e834e9352d 100644
--- a/sota-implementations/sac/utils.py
+++ b/sota-implementations/sac/utils.py
@@ -105,7 +105,7 @@ def make_environment(cfg, logger=None):
 # ---------------------------


-def make_collector(cfg, train_env, actor_model_explore):
+def make_collector(cfg, train_env, actor_model_explore, compile_mode):
     """Make collector."""
     device = cfg.collector.device
     if device in ("", None):
@@ -120,6 +120,8 @@ def make_collector(cfg, train_env, actor_model_explore):
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
         device=device,
+        compile_policy={"mode": compile_mode} if compile_mode else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
     )
     collector.set_seed(cfg.env.seed)
     return collector
@@ -169,7 +171,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
     """Make SAC agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec_unbatched
+    action_spec = train_env.action_spec_unbatched.to(device)
     actor_net_kwargs = {
         "num_cells": cfg.network.hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],
@@ -188,7 +190,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
     actor_extractor = NormalParamExtractor(
         scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}",
         scale_lb=cfg.network.scale_lb,
-    )
+    ).to(device)
     actor_net = nn.Sequential(actor_net, actor_extractor)

     in_keys_actor = in_keys
@@ -211,14 +213,11 @@ def make_sac_agent(cfg, train_env, eval_env, device):
     )

     # Define Critic Network
-    qvalue_net_kwargs = {
-        "num_cells": cfg.network.hidden_sizes,
-        "out_features": 1,
-        "activation_class": get_activation(cfg),
-    }
-
     qvalue_net = MLP(
-        **qvalue_net_kwargs,
+        num_cells=cfg.network.hidden_sizes,
+        out_features=1,
+        activation_class=get_activation(cfg),
+        device=device,
     )

     qvalue = ValueOperator(
@@ -226,7 +225,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
         module=qvalue_net,
     )

-    model = nn.ModuleList([actor, qvalue]).to(device)
+    model = nn.ModuleList([actor, qvalue])

     # init nets
     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
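
Usage sketch (not authoritative; assumes the standard Hydra command-line overrides that this script already relies on through its config.yaml, and illustrative flag values):

    python sota-implementations/sac/sac.py compile.compile=True compile.cudagraphs=False

With compile.compile=True and compile.compile_mode left empty, main() falls back to the "reduce-overhead" mode, or to "default" when compile.cudagraphs is also enabled, as implemented in the diff above.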