Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 18, 2024
2 parents 0ad33aa + 99509bc commit f3e5540
Show file tree
Hide file tree
Showing 19 changed files with 326 additions and 207 deletions.
14 changes: 12 additions & 2 deletions sota-implementations/a2c/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,21 @@ Please note that each example is independent of each other for the sake of simpl
You can execute the A2C algorithm on Atari environments by running the following command:

```bash
python a2c_atari.py
python a2c_atari.py compile.compile=1 compile.cudagraphs=1
```


You can execute the A2C algorithm on MuJoCo environments by running the following command:

```bash
python a2c_mujoco.py
python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1
```

## Runtimes

Runtimes when executed on H100:

| Environment | Eager | Compile | Compile+cudagraphs |
|-------------|-----------|-----------|--------------------|
| MUJOCO | < 25 mins | < 23 mins | < 20 mins |
| ATARI | < 85 mins | < 60 mins | < 45 mins |
195 changes: 114 additions & 81 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,37 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder
import torch

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821

import time
from copy import deepcopy

import torch.optim
import tqdm
from tensordict import from_module
from tensordict.nn import CudaGraphModule

from tensordict import TensorDict
from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.objectives import A2CLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import eval_model, make_parallel_env, make_ppo_models

device = "cpu" if not torch.cuda.device_count() else "cuda"
device = cfg.loss.device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
Expand All @@ -35,28 +43,16 @@ def main(cfg: "DictConfig"): # noqa: F821
test_interval = cfg.logger.test_interval // frame_skip

# Create models (check utils_atari.py)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
actor, critic, critic_head = (
actor.to(device),
critic.to(device),
critic_head.to(device),
)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
with from_module(actor).data.to("meta").to_module(actor):
actor_eval = deepcopy(actor)
actor_eval.eval()
from_module(actor).data.to_module(actor_eval)

# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(frames_per_batch),
storage=LazyTensorStorage(frames_per_batch, device=device),
sampler=sampler,
batch_size=mini_batch_size,
)
Expand All @@ -67,6 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821
lmbda=cfg.loss.gae_lambda,
value_network=critic,
average_gae=True,
vectorized=not cfg.compile.compile,
device=device,
)
loss_module = A2CLoss(
actor_network=actor,
Expand All @@ -83,9 +81,10 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create optimizer
optim = torch.optim.Adam(
loss_module.parameters(),
lr=cfg.optim.lr,
lr=torch.tensor(cfg.optim.lr, device=device),
weight_decay=cfg.optim.weight_decay,
eps=cfg.optim.eps,
capturable=device.type == "cuda",
)

# Create logger
Expand Down Expand Up @@ -115,19 +114,71 @@ def main(cfg: "DictConfig"): # noqa: F821
)
test_env.eval()

# update function
def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
# Forward pass A2C loss
loss = loss_module(batch)

loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
loss_sum.backward()
gn = torch.nn.utils.clip_grad_norm_(
loss_module.parameters(), max_norm=max_grad_norm
)

# Update the networks
optim.step()
optim.zero_grad(set_to_none=True)

return (
loss.select("loss_critic", "loss_entropy", "loss_objective")
.detach()
.set("grad_norm", gn)
)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.compile.cudagraphs:
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
adv_module = CudaGraphModule(adv_module)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
policy_device=device,
compile_policy={"mode": compile_mode} if cfg.compile.compile else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=total_frames)
num_mini_batches = frames_per_batch // mini_batch_size
total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
lr = cfg.optim.lr

sampling_start = time.time()
for i, data in enumerate(collector):
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
data = next(c_iter)

log_info = {}
sampling_time = time.time() - sampling_start
frames_in_batch = data.numel()
collected_frames += frames_in_batch * frame_skip
pbar.update(data.numel())
Expand All @@ -144,94 +195,76 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

losses = TensorDict(batch_size=[num_mini_batches])
training_start = time.time()
losses = []

# Compute GAE
with torch.no_grad():
with torch.no_grad(), timeit("advantage"):
torch.compiler.cudagraph_mark_step_begin()
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg.optim.anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
num_network_updates += 1

# Forward pass A2C loss
loss = loss_module(batch)
losses[k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
).detach()
loss_sum = (
loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
)
with timeit("rb - emptying"):
data_buffer.empty()
with timeit("rb - extending"):
data_buffer.extend(data_reshape)

# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
)
with timeit("optim"):
for batch in data_buffer:

# Update the networks
optim.step()
optim.zero_grad()
# Linearly decrease the learning rate and clip epsilon
with timeit("optim - lr"):
alpha = 1.0
if cfg.optim.anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
group["lr"].copy_(lr * alpha)

num_network_updates += 1

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss = update(batch).clone()
losses.append(loss)

# Get training losses
training_time = time.time() - training_start
losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
losses = torch.stack(losses).float().mean()

for key, value in losses.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg.optim.lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
"train/lr": lr * alpha,
}
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(
ExplorationType.DETERMINISTIC
), timeit("eval"):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
{
"test/reward": test_rewards.mean(),
"test/eval_time": eval_time,
}
)
actor.train()
if i % 200 == 0:
log_info.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger:
for key, value in log_info.items():
logger.log_scalar(key, value, collected_frames)

collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f3e5540

Please sign in to comment.