Skip to content

Commit

Permalink
[Feature] PPO compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: ada52e4e1f02e973e751bbbdaf5312cbf1dcf02c
Pull Request resolved: #2652
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 62e7df8 commit 3b099ec
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 135 deletions.
2 changes: 2 additions & 0 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import eval_model, make_dqn_model, make_env

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/dqn/dqn_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_cartpole import eval_model, make_dqn_model, make_env

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down
6 changes: 6 additions & 0 deletions sota-implementations/ppo/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ optim:
weight_decay: 0.0
max_grad_norm: 0.5
anneal_lr: True
device:

# loss
loss:
Expand All @@ -37,3 +38,8 @@ loss:
critic_coef: 1.0
entropy_coef: 0.01
loss_critic_type: l2

compile:
compile: False
compile_mode:
cudagraphs: False
6 changes: 6 additions & 0 deletions sota-implementations/ppo/config_mujoco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ optim:
lr: 3e-4
weight_decay: 0.0
anneal_lr: True
device:

# loss
loss:
Expand All @@ -34,3 +35,8 @@ loss:
critic_coef: 0.25
entropy_coef: 0.0
loss_critic_type: l2

compile:
compile: False
compile_mode:
cudagraphs: False
160 changes: 94 additions & 66 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
"""
from __future__ import annotations

import warnings

import hydra
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821

import time

import torch.optim
import tqdm

Expand All @@ -32,7 +34,15 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import eval_model, make_parallel_env, make_ppo_models

device = "cpu" if not torch.cuda.device_count() else "cuda"
torch.set_float32_matmul_precision("high")

device = cfg.optim.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand All @@ -41,9 +51,17 @@ def main(cfg: "DictConfig"): # noqa: F821
mini_batch_size = cfg.loss.mini_batch_size // frame_skip
test_interval = cfg.logger.test_interval // frame_skip

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create models (check utils_atari.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)
actor, critic = make_ppo_models(cfg.env.env_name, device=device)

# Create collector
collector = SyncDataCollector(
Expand All @@ -54,6 +72,8 @@ def main(cfg: "DictConfig"): # noqa: F821
device="cpu",
storing_device="cpu",
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Create data buffer
Expand Down Expand Up @@ -122,14 +142,54 @@ def main(cfg: "DictConfig"): # noqa: F821
# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=total_frames)
num_mini_batches = frames_per_batch // mini_batch_size
total_network_updates = (
(total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches
)

sampling_start = time.time()
def update(batch, num_network_updates):
    """Run one PPO optimization step on a single mini-batch.

    This is a closure: ``optim``, ``loss_module``, ``device``,
    ``total_network_updates`` and the ``cfg_*`` values are read from the
    enclosing scope rather than passed in.

    Args:
        batch: mini-batch to train on; it is moved to ``device`` before the
            loss is computed.
        num_network_updates: number of updates performed so far, used to
            compute the annealing coefficient ``alpha``.

    Returns:
        The detached output of ``loss_module`` with an extra ``"alpha"``
        entry holding the annealing coefficient used for this step.
    """
    # Clear gradients up front so the backward pass below starts from scratch.
    optim.zero_grad(set_to_none=True)

    # Linearly decrease the learning rate and clip epsilon.
    # alpha stays at 1.0 (no scaling) unless learning-rate annealing is enabled.
    alpha = 1.0
    if cfg_optim_anneal_lr:
        alpha = 1 - (num_network_updates / total_network_updates)
        for group in optim.param_groups:
            group["lr"] = cfg_optim_lr * alpha
    if cfg_loss_anneal_clip_eps:
        # In-place copy: the loss module's clip_epsilon object is kept and
        # only its value is overwritten.
        loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
    # NOTE(review): this only rebinds the local parameter and the new value is
    # not returned, so the caller's counter is not advanced by this function.
    # Confirm the counter is incremented elsewhere; otherwise alpha never
    # changes between calls.
    num_network_updates += 1
    # Move the mini-batch to the training device.
    batch = batch.to(device, non_blocking=True)

    # Forward pass PPO loss
    loss = loss_module(batch)
    # The optimized objective is the sum of the three loss terms.
    loss_sum = (
        loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
    )
    # Backward pass
    loss_sum.backward()
    # Clip the gradient norm of the loss module's parameters before stepping.
    torch.nn.utils.clip_grad_norm_(
        loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
    )

    # Update the networks
    optim.step()
    # Attach alpha to the detached losses so the caller can log the values
    # derived from it.
    return loss.detach().set("alpha", alpha)


if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
adv_module = CudaGraphModule(adv_module)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
Expand All @@ -142,13 +202,16 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):
collector_iter = iter(collector)

for i in range(len(collector)):
with timeit("collecting"):
data = next(collector_iter)

log_info = {}
sampling_time = time.time() - sampling_start
frames_in_batch = data.numel()
collected_frames += frames_in_batch * frame_skip
pbar.update(data.numel())
pbar.update(frames_in_batch)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
Expand All @@ -162,96 +225,61 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

training_start = time.time()
for j in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data.to(device, non_blocking=True))
data_reshape = data.reshape(-1)
# Update the data buffer
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
num_network_updates += 1
# Get a data batch
batch = batch.to(device, non_blocking=True)

# Forward pass PPO loss
loss = loss_module(batch)
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
).detach()
loss_sum = (
loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
)
# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
)
with timeit("training"):
for j in range(cfg_loss_ppo_epochs):

# Update the networks
optim.step()
optim.zero_grad()
# Compute GAE
with torch.no_grad(), timeit("adv"):
data = adv_module(data.to(device))
with timeit("rb - extend"):
# Update the data buffer
data_reshape = data.reshape(-1)
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):

loss = update(batch, num_network_updates=num_network_updates)
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
)

# Get training losses and times
training_time = time.time() - training_start
losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses_mean.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg_optim_lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
"train/lr": loss["alpha"] * cfg_optim_lr,
"train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon,
}
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC), timeit("eval"):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg_logger_num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
{
"eval/reward": test_rewards.mean(),
"eval/time": eval_time,
}
)
actor.train()

if logger:
log_info.update(timeit.todict(prefix="time"))
for key, value in log_info.items():
logger.log_scalar(key, value, collected_frames)

collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()

end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
Loading

0 comments on commit 3b099ec

Please sign in to comment.