[Feature] TD3 compatibility with compile
ghstack-source-id: bf4ac88e13e30edf83f34cd838f3a82d323411ba
Pull Request resolved: #2656
vmoens committed Dec 16, 2024
1 parent 25ad990 commit 8bb9f50
Showing 7 changed files with 481 additions and 450 deletions.
5 changes: 5 additions & 0 deletions sota-implementations/td3/config.yaml
@@ -52,3 +52,8 @@ logger:
  mode: online
  eval_iter: 25000
  video: False

compile:
  compile: False
  compile_mode:
  cudagraphs: False
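
The new compile block gates torch.compile and CUDA-graph capture of the training step. A condensed sketch of how these keys are resolved, mirroring the logic this commit adds to td3.py below (the resolve_compile_mode helper name is illustrative; cfg is the Hydra config loaded from this file):

def resolve_compile_mode(cfg):
    # No compilation unless compile.compile is enabled.
    compile_mode = None
    if cfg.compile.compile:
        compile_mode = cfg.compile.compile_mode
        if compile_mode in ("", None):
            # When CUDA graphs are requested, graph capture is handled by
            # CudaGraphModule, so plain "default" compilation is used instead
            # of "reduce-overhead" (which would add its own graph capture).
            compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"
    return compile_mode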
192 changes: 113 additions & 79 deletions sota-implementations/td3/td3.py
@@ -12,14 +12,16 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup, timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type

@@ -36,6 +38,9 @@
)


torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
device = cfg.network.device
@@ -44,7 +49,8 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
device = torch.device(device)
else:
device = torch.device(device)

# Create logger
exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
@@ -67,31 +73,88 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create environments
train_env, eval_env = make_environment(cfg, logger=logger)
train_env, eval_env = make_environment(cfg, logger=logger, device=device)

# Create agent
model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

# Create TD3 loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile_mode=compile_mode,
device=device,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
device=device,
compile=bool(compile_mode),
)

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

prb = cfg.replay_buffer.prb

def update(update_actor, prb=prb):
sampled_tensordict = replay_buffer.sample()

# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
q_loss.backward()
optimizer_critic.step()
optimizer_critic.zero_grad(set_to_none=True)

# Update actor
if update_actor:
actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)

actor_loss.backward()
optimizer_actor.step()
optimizer_actor.zero_grad(set_to_none=True)

# Update target params
target_net_updater.step()
else:
actor_loss = q_loss.new_zeros(())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

return q_loss.detach(), actor_loss.detach()

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
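
A note on the update() closure defined above: it always returns a detached (q_loss, actor_loss) pair, substituting a zero tensor when the actor step is skipped, so the output structure is identical on every call; that stability is presumably what allows compile_with_warmup and CudaGraphModule to wrap it without retracing. A self-contained toy sketch of the same pattern (every name below is illustrative, not part of the commit):

import torch
from torch import nn

net = nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)

def toy_update(do_second_branch: bool):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
    if do_second_branch:
        extra = loss.detach() * 2  # stand-in for the delayed actor loss
    else:
        # A zero placeholder with matching dtype/device keeps the return
        # signature stable across calls.
        extra = loss.new_zeros(())
    return loss.detach(), extra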

@@ -102,76 +165,53 @@ def main(cfg: "DictConfig"): # noqa: F821
* cfg.optim.utd_ratio
)
delayed_updates = cfg.optim.policy_update_delay
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
update_counter = 0

sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
exploration_policy[1].step(tensordict.numel())
collector_iter = iter(collector)
total_iter = len(collector)

for _ in range(total_iter):
timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)

with timeit("collect"):
tensordict = next(collector_iter)

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
pbar.update(current_frames)

with timeit("rb - extend"):
# Add to replay buffer
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for _ in range(num_updates):

# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())

# Update actor
if update_actor:
actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

actor_losses.append(actor_loss.item())

# Update target params
target_net_updater.step()

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
with timeit("train"):
# Optimization steps
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for _ in range(num_updates):
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
q_loss, actor_loss = update(update_actor)

q_losses.append(q_loss.clone())
if update_actor:
actor_losses.append(actor_loss.clone())

episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
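
Two details of the update invocation above are worth noting: torch.compiler.cudagraph_mark_step_begin() marks an iteration boundary before each call into the (possibly graph-captured) update, and the returned losses are cloned before being accumulated, because a replayed CUDA graph reuses its output buffers on the next step. A minimal, self-contained sketch of that pattern with a trivial stand-in for update (all names below are illustrative):

import torch

def fake_update(update_actor: bool):
    # Stand-in for the compiled/captured update() defined earlier.
    return torch.ones(()), torch.zeros(())

q_losses = []
for _ in range(3):
    # Tell the compiler a new iteration begins, so tensors owned by a previous
    # CUDA graph replay may safely be overwritten (harmless when graphs are off).
    torch.compiler.cudagraph_mark_step_begin()
    q_loss, actor_loss = fake_update(True)
    # Clone before storing: a replayed graph would otherwise overwrite this
    # tensor's storage on the next call.
    q_losses.append(q_loss.clone())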
@@ -183,45 +223,39 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
metrics_to_log["train/reward"] = episode_rewards.mean()
metrics_to_log["train/episode_length"] = episode_length.sum() / len(
episode_length
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/q_loss"] = torch.stack(q_losses).mean()
if update_actor:
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
metrics_to_log["train/a_loss"] = torch.stack(actor_losses).mean()

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
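
For reference, the wrapping order used by this commit is: build the eager update closure, optionally compile it with compile_with_warmup, then optionally capture it with CudaGraphModule. A sketch of that composition, assuming only the two imports already shown in the diff (the wrap_update helper is illustrative, not part of the commit):

from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup

def wrap_update(update, compile_mode, use_cudagraphs):
    if compile_mode is not None:
        # warmup=1: presumably one eager call before the compiled version takes over.
        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
    if use_cudagraphs:
        # Graph capture wraps the (compiled) update; arguments match the diff above.
        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
    return update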