
[Feature] TD3 compatibility with compile #2656

Closed
wants to merge 19 commits
5 changes: 5 additions & 0 deletions sota-implementations/td3/config.yaml
@@ -52,3 +52,8 @@ logger:
mode: online
eval_iter: 25000
video: False

compile:
compile: False
compile_mode:
cudagraphs: False
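
The three new flags are consumed in td3.py below. As a rough standalone sketch (the function name is illustrative; only the config keys come from this diff), the mode resolution amounts to:

```python
def resolve_compile_mode(compile_cfg):
    """Sketch: map the new compile config block to a torch.compile mode (or None)."""
    if not compile_cfg.compile:
        return None
    mode = compile_cfg.compile_mode
    if mode in ("", None):
        # With CUDA graphs, capture already removes launch overhead, so keep the
        # default mode; otherwise ask the compiler to reduce overhead itself.
        mode = "default" if compile_cfg.cudagraphs else "reduce-overhead"
    return mode
```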
194 changes: 115 additions & 79 deletions sota-implementations/td3/td3.py
@@ -12,14 +12,16 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup, timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type

@@ -36,6 +38,9 @@
)


torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
device = cfg.network.device
@@ -44,7 +49,8 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
device = torch.device(device)
else:
device = torch.device(device)

# Create logger
exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
@@ -67,31 +73,83 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create environments
train_env, eval_env = make_environment(cfg, logger=logger)
train_env, eval_env = make_environment(cfg, logger=logger, device=device)

# Create agent
model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

# Create TD3 loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile_mode=compile_mode,
device=device,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
device=device,
compile=bool(compile_mode),
)

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

prb = cfg.replay_buffer.prb

def update(sampled_tensordict, update_actor, prb=prb):

# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
q_loss.backward()
optimizer_critic.step()
optimizer_critic.zero_grad(set_to_none=True)

# Update actor
if update_actor:
actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)

actor_loss.backward()
optimizer_actor.step()
optimizer_actor.zero_grad(set_to_none=True)

# Update target params
target_net_updater.step()
else:
actor_loss = q_loss.new_zeros(())

return q_loss.detach(), actor_loss.detach()

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -102,76 +160,60 @@ def main(cfg: "DictConfig"): # noqa: F821
* cfg.optim.utd_ratio
)
delayed_updates = cfg.optim.policy_update_delay
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
update_counter = 0

sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
exploration_policy[1].step(tensordict.numel())
collector_iter = iter(collector)
total_iter = len(collector)

for _ in range(total_iter):
timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)

with timeit("collect"):
tensordict = next(collector_iter)

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
pbar.update(current_frames)

with timeit("rb - extend"):
# Add to replay buffer
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for _ in range(num_updates):

# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())

# Update actor
if update_actor:
actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

actor_losses.append(actor_loss.item())

# Update target params
target_net_updater.step()

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
with timeit("train"):
# Optimization steps
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for _ in range(num_updates):
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
q_loss, actor_loss = update(sampled_tensordict, update_actor)

# Update priority
if prb:
with timeit("rb - priority"):
replay_buffer.update_priority(sampled_tensordict)

q_losses.append(q_loss.clone())
if update_actor:
actor_losses.append(actor_loss.clone())

episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
@@ -183,45 +225,39 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
metrics_to_log["train/reward"] = episode_rewards.mean()
metrics_to_log["train/episode_length"] = episode_length.sum() / len(
episode_length
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/q_loss"] = torch.stack(q_losses).mean()
if update_actor:
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
metrics_to_log["train/a_loss"] = torch.stack(actor_losses).mean()

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
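
As a side note, the `timeit` helper used throughout the loop above is a context-manager profiler; a minimal sketch of the pattern, assuming the same API as in this diff (`timeit("key")`, `timeit.printevery`, `timeit.todict`):

```python
import time
from torchrl._utils import timeit

total_iter = 100
for _ in range(total_iter):
    # Periodically print aggregated timings, as in the training loop above.
    timeit.printevery(num_prints=10, total_count=total_iter, erase=True)
    with timeit("work"):
        time.sleep(0.001)  # stand-in for collection / optimization

# Timings can also be exported for logging, as done above with
# metrics_to_log.update(timeit.todict(prefix="time")).
```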