diff --git a/sota-implementations/td3_bc/config.yaml b/sota-implementations/td3_bc/config.yaml index 54275a94bc2..1456f2f2acf 100644 --- a/sota-implementations/td3_bc/config.yaml +++ b/sota-implementations/td3_bc/config.yaml @@ -43,3 +43,8 @@ logger: eval_steps: 1000 eval_envs: 1 video: False + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 75be949df90..35563777962 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -11,13 +11,16 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import compile_with_warmup, timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -72,7 +75,16 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create replay buffer - replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, device=device) + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" # Create agent model, _ = make_td3_agent(cfg, eval_env, device) @@ -83,54 +95,73 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module) - gradient_steps = cfg.optim.gradient_steps - evaluation_interval = cfg.logger.eval_iter - eval_steps = cfg.logger.eval_steps - delayed_updates = cfg.optim.policy_update_delay - update_counter = 0 - pbar = tqdm.tqdm(range(gradient_steps)) - # Training loop - start_time = time.time() - for i in pbar: - pbar.update(1) - # Update actor every delayed_updates - update_counter += 1 - update_actor = update_counter % delayed_updates == 0 - - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to(device) - else: - sampled_tensordict = sampled_tensordict.clone() - + def update(sampled_tensordict, update_actor): # Compute loss q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) # Update critic - optimizer_critic.zero_grad() q_loss.backward() optimizer_critic.step() - q_loss.item() - - to_log = {"q_loss": q_loss.item()} + optimizer_critic.zero_grad(set_to_none=True) # Update actor if update_actor: actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict) - optimizer_actor.zero_grad() actor_loss.backward() optimizer_actor.step() + optimizer_actor.zero_grad(set_to_none=True) # Update target params target_net_updater.step() + else: + actorloss_metadata = {} + actor_loss = q_loss.new_zeros(()) + metadata = TensorDict(actorloss_metadata) + metadata.set("q_loss", q_loss.detach()) + metadata.set("actor_loss", actor_loss.detach()) + return metadata + + if cfg.compile.compile: + update = compile_with_warmup(update, mode=compile_mode, warmup=1) + + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + + gradient_steps = cfg.optim.gradient_steps + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + delayed_updates = cfg.optim.policy_update_delay + pbar = tqdm.tqdm(range(gradient_steps)) + # Training loop + for update_counter in pbar: + timeit.printevery(num_prints=1000, total_count=gradient_steps, erase=True) - to_log["actor_loss"] = actor_loss.item() - to_log.update(actorloss_metadata) + # Update actor every delayed_updates + update_actor = update_counter % delayed_updates == 0 + + with timeit("rb - sample"): + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + metadata = update(sampled_tensordict, update_actor).clone() + + to_log = {} + if update_actor: + to_log.update(metadata.to_dict()) + else: + to_log.update(metadata.exclude("actor_loss").to_dict()) # evaluation - if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + if update_counter % evaluation_interval == 0: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) @@ -138,12 +169,12 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_reward = eval_td["next", "reward"].sum(1).mean().item() to_log["evaluation_reward"] = eval_reward if logger is not None: - log_metrics(logger, to_log, i) + to_log.update(timeit.todict(prefix="time")) + log_metrics(logger, to_log, update_counter) if not eval_env.is_closed: eval_env.close() pbar.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py index d0c3161861d..c7b99e4f0e3 100644 --- a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -7,7 +7,7 @@ import functools import torch -from tensordict.nn import TensorDictSequential +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn, optim from torchrl.data.datasets.d4rl import D4RLExperienceReplay @@ -26,14 +26,7 @@ ) from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import ( - AdditiveGaussianModule, - MLP, - SafeModule, - SafeSequential, - TanhModule, - ValueOperator, -) +from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator from torchrl.objectives import SoftUpdate from torchrl.objectives.td3_bc import TD3BCLoss @@ -98,17 +91,19 @@ def make_environment(cfg, logger=None): # --------------------------- -def make_offline_replay_buffer(rb_cfg): +def make_offline_replay_buffer(rb_cfg, device): data = D4RLExperienceReplay( dataset_id=rb_cfg.dataset, split_trajs=False, batch_size=rb_cfg.batch_size, - sampler=SamplerWithoutReplacement(drop_last=False), + # drop_last for compile + sampler=SamplerWithoutReplacement(drop_last=True), prefetch=4, direct_download=True, ) data.append_transform(DoubleToFloat()) + data.append_transform(lambda td: td.to(device)) return data @@ -122,26 +117,22 @@ def make_td3_agent(cfg, train_env, device): """Make TD3 agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] - actor_net_kwargs = { - "num_cells": cfg.network.hidden_sizes, - "out_features": action_spec.shape[-1], - "activation_class": get_activation(cfg), - } + action_spec = train_env.action_spec_unbatched.to(device) - actor_net = MLP(**actor_net_kwargs) + actor_net = MLP( + num_cells=cfg.network.hidden_sizes, + out_features=action_spec.shape[-1], + activation_class=get_activation(cfg), + device=device, + ) in_keys_actor = in_keys - actor_module = SafeModule( + actor_module = TensorDictModule( actor_net, in_keys=in_keys_actor, - out_keys=[ - "param", - ], + out_keys=["param"], ) - actor = SafeSequential( + actor = TensorDictSequential( actor_module, TanhModule( in_keys=["param"], @@ -151,14 +142,11 @@ def make_td3_agent(cfg, train_env, device): ) # Define Critic Network - qvalue_net_kwargs = { - "num_cells": cfg.network.hidden_sizes, - "out_features": 1, - "activation_class": get_activation(cfg), - } - qvalue_net = MLP( - **qvalue_net_kwargs, + num_cells=cfg.network.hidden_sizes, + out_features=1, + activation_class=get_activation(cfg), + device=device, ) qvalue = ValueOperator( @@ -166,7 +154,7 @@ def make_td3_agent(cfg, train_env, device): module=qvalue_net, ) - model = nn.ModuleList([actor, qvalue]).to(device) + model = nn.ModuleList([actor, qvalue]) # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):