[Feature] A2C compatibility with compile
ghstack-source-id: 6f2f140f32f0aaf886f605bd1eb538428fddd901
Pull Request resolved: #2464
vmoens committed Nov 11, 2024
1 parent d894358 commit 0611c42
Showing 18 changed files with 491 additions and 282 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_objectives_benchmarks.py
@@ -50,7 +50,7 @@
) # Anything from 2.5, incl. nightlies, allows for fullgraph


@pytest.fixture(scope="module")
@pytest.fixture(scope="module", autouse=True)
def set_default_device():
cur_device = torch.get_default_device()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
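
The fixture change above makes set_default_device apply to every benchmark in the module automatically, instead of only to tests that request it by name. A minimal sketch of the autouse pattern, assuming the fixture simply swaps the default device for the duration of the module (illustrative, not the repository's exact code):

import pytest
import torch

@pytest.fixture(scope="module", autouse=True)
def set_default_device():
    # Runs once for the whole module, without each test listing the fixture as an argument.
    previous = torch.get_default_device()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)
    yield
    # Restore the previous default device once the module's tests are done.
    torch.set_default_device(previous)
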
158 changes: 99 additions & 59 deletions sota-implementations/a2c/a2c_atari.py
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hydra
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder

@@ -15,17 +16,21 @@ def main(cfg: "DictConfig"): # noqa: F821
import torch.optim
import tqdm

from tensordict import TensorDict
from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.objectives import A2CLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import eval_model, make_parallel_env, make_ppo_models

device = "cpu" if not torch.cuda.device_count() else "cuda"
device = cfg.loss.device
if not device:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
else:
device = torch.device(device)

# Correct for frame_skip
frame_skip = 4
@@ -35,28 +40,12 @@ def main(cfg: "DictConfig"): # noqa: F821
test_interval = cfg.logger.test_interval // frame_skip

# Create models (check utils_atari.py)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
actor, critic, critic_head = (
actor.to(device),
critic.to(device),
critic_head.to(device),
)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)

# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(frames_per_batch),
storage=LazyTensorStorage(frames_per_batch, device=device),
sampler=sampler,
batch_size=mini_batch_size,
)
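
Replacing LazyMemmapStorage with LazyTensorStorage(..., device=device) keeps the buffer's storage on the training device, so sampled mini-batches no longer need a per-batch copy with .to(device). A rough usage sketch under those assumptions (sizes and keys are illustrative, not from this commit):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1024, device=device),  # storage preallocated on the training device
    sampler=SamplerWithoutReplacement(),
    batch_size=256,
)
rb.extend(TensorDict({"obs": torch.randn(1024, 4)}, batch_size=[1024]))
for minibatch in rb:
    # Mini-batches come back already on the training device; no extra copy needed.
    assert minibatch["obs"].device == device
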
@@ -67,6 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821
lmbda=cfg.loss.gae_lambda,
value_network=critic,
average_gae=True,
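# When the loss is compiled, fall back to the non-vectorized (time-loop) GAE implementation.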
vectorized=not cfg.loss.compile,
)
loss_module = A2CLoss(
actor_network=actor,
@@ -83,9 +73,10 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create optimizer
optim = torch.optim.Adam(
loss_module.parameters(),
lr=cfg.optim.lr,
lr=torch.tensor(cfg.optim.lr, device=device),
weight_decay=cfg.optim.weight_decay,
eps=cfg.optim.eps,
capturable=device.type == "cuda",
)

# Create logger
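
The optimizer above takes its learning rate as a tensor and sets capturable=True, so the Adam step stays graph-capturable and the learning rate can later be annealed in place (see the update loop below). A minimal sketch of that pattern with a toy model, assuming a CUDA device is available (not part of this commit):

import torch

model = torch.nn.Linear(4, 2, device="cuda")
lr = torch.tensor(1e-3, device="cuda")  # lr lives on the device as a tensor
optim = torch.optim.Adam(model.parameters(), lr=lr, capturable=True)

loss = model(torch.randn(8, 4, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)

# Anneal by writing into the captured tensor instead of rebinding a Python float.
for group in optim.param_groups:
    group["lr"].copy_(lr * 0.5)
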
@@ -115,16 +106,72 @@ def main(cfg: "DictConfig"): # noqa: F821
)
test_env.eval()

# update function
def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
# Forward pass A2C loss
loss = loss_module(batch)

loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
loss_sum.backward()
gn = torch.nn.utils.clip_grad_norm_(
loss_module.parameters(), max_norm=max_grad_norm
)

# Update the networks
optim.step()
optim.zero_grad(set_to_none=True)

return (
loss.select("loss_critic", "loss_entropy", "loss_objective")
.detach()
.set("grad_norm", gn)
)

compile_mode = None
if cfg.loss.compile:
compile_mode = cfg.loss.compile_mode
if compile_mode in ("", None):
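# With no explicit mode: use plain compilation when CudaGraphModule will capture the graph,
# otherwise "reduce-overhead", which lets torch.compile manage CUDA graphs itself.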
if cfg.loss.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.loss.cudagraphs:
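# CudaGraphModule captures the wrapped call as a CUDA graph and replays it; warmup runs a few
# iterations first so lazily-initialized state (e.g. optimizer buffers) exists before capture.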
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
adv_module = CudaGraphModule(adv_module)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
policy_device=device,
compile_policy={"mode": compile_mode} if cfg.loss.compile else False,
cudagraph_policy=cfg.loss.cudagraphs,
)

# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=total_frames)
num_mini_batches = frames_per_batch // mini_batch_size
total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
lr = cfg.optim.lr

sampling_start = time.time()
for i, data in enumerate(collector):
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
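# Tell torch.compile's CUDA-graph machinery that a new iteration begins, so tensors
# returned by the previous captured step may be overwritten.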
torch.compiler.cudagraph_mark_step_begin()
data = next(c_iter)

log_info = {}
sampling_time = time.time() - sampling_start
@@ -144,61 +191,55 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

losses = TensorDict({}, batch_size=[num_mini_batches])
losses = []
training_start = time.time()

# Compute GAE
with torch.no_grad():
with torch.no_grad(), timeit("advantage"):
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg.optim.anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
num_network_updates += 1

# Forward pass A2C loss
loss = loss_module(batch)
losses[k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
).detach()
loss_sum = (
loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
)
with timeit("emptying"):
data_buffer.empty()
with timeit("extending"):
data_buffer.extend(data_reshape)

# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
)
with timeit("optim"):
for batch in data_buffer:

# Linearly decrease the learning rate
with timeit("optim - lr"):
alpha = 1.0
if cfg.optim.anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
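# group["lr"] is a tensor (capturable Adam), so anneal it in place; rebinding a Python float
# would invalidate the value used inside the captured/compiled step.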
for group in optim.param_groups:
group["lr"].copy_(lr * alpha)

# Update the networks
optim.step()
optim.zero_grad()
num_network_updates += 1

with timeit("optim - update"):
torch.compiler.cudagraph_mark_step_begin()
loss = update(batch)
losses.append(loss)

# Get training losses
training_time = time.time() - training_start
losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
losses = torch.stack(losses).float().mean()

for key, value in losses.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg.optim.lr,
"train/lr": lr * alpha,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
**timeit.todict(prefix="time"),
}
)
if i % 200 == 0:
timeit.print()
timeit.erase()

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
@@ -223,7 +264,6 @@ def main(cfg: "DictConfig"): # noqa: F821
for key, value in log_info.items():
logger.log_scalar(key, value, collected_frames)

collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()