
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 16, 2024
1 parent dd6c018 commit c6b2fb1
Showing 3 changed files with 104 additions and 85 deletions.
5 changes: 5 additions & 0 deletions sota-implementations/sac/config.yaml
@@ -51,3 +51,8 @@ logger:
mode: online
eval_iter: 25000
video: False

compile:
compile: False
compile_mode:
cudagraphs: False
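
The new compile block gates torch.compile for the training step and the collector policy: compile switches compilation on, compile_mode selects the torch.compile mode (left empty it falls back to a default chosen in sac.py), and cudagraphs additionally wraps the update in a CudaGraphModule. A minimal sketch of how an empty mode resolves, assuming an OmegaConf config that mirrors this file (illustration only, not part of the commit):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"compile": {"compile": True, "compile_mode": None, "cudagraphs": False}}
)
compile_mode = None
if cfg.compile.compile:
    # An empty mode falls back to "default" under cudagraphs, else "reduce-overhead".
    compile_mode = cfg.compile.compile_mode or (
        "default" if cfg.compile.cudagraphs else "reduce-overhead"
    )
print(compile_mode)  # -> "reduce-overhead"
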
163 changes: 89 additions & 74 deletions sota-implementations/sac/sac.py
@@ -12,7 +12,7 @@
"""
from __future__ import annotations

import time
import warnings

import hydra

@@ -21,8 +21,11 @@
import torch.cuda
import tqdm
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
@@ -75,16 +78,27 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create SAC loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg, train_env, exploration_policy, compile_mode=compile_mode
)

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
device=device,
)

# Create optimizers
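
Two changes in this hunk work together: the resolved compile mode is handed to the collector, and the replay buffer is now created on the training device instead of CPU, which is why a later hunk drops the .cpu() call on extend and the .to(device) copy after sampling. A generic illustration of a device-resident buffer built from torchrl's stock classes (a sketch, not the repo's make_replay_buffer):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(100_000, device=device),  # storage lives on the device
    batch_size=256,
)
rb.extend(TensorDict({"obs": torch.randn(512, 8)}, [512]).to(device))
batch = rb.sample()  # already on the training device, no extra copy needed
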
@@ -93,9 +107,36 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_critic,
optimizer_alpha,
) = make_sac_optimizer(cfg, loss_module)
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
del optimizer_actor, optimizer_critic, optimizer_alpha

def update(sampled_tensordict):
# Compute loss
loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
alpha_loss = loss_td["loss_alpha"]

(actor_loss + q_loss + alpha_loss).sum().backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

# Update qnet_target params
target_net_updater.step()
return loss_td.detach()

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
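
The three per-loss optimizers are grouped into a single optimizer, and the whole update (forward pass, one backward over the summed losses, optimizer step, target-network update) becomes one function that can be compiled with compile_with_warmup and captured by a CudaGraphModule. A toy sketch of that pattern, assuming a CUDA device and substituting a made-up linear model and MSE loss for the SAC losses; the capturable flag on Adam is an extra precaution for graph capture, not something the script sets:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup

# Toy stand-in for the SAC networks/losses: one linear "critic", one optimizer.
net = torch.nn.Linear(8, 1, device="cuda")
# capturable=True keeps Adam's step counter on the GPU so the step can be
# replayed safely inside a CUDA graph.
optim = torch.optim.Adam(net.parameters(), lr=3e-4, capturable=True)

def update(td):
    # Single backward over the (summed) loss, single grouped optimizer step.
    loss = (net(td["obs"]) - td["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    return TensorDict({"loss": loss.detach()}, [])

update = compile_with_warmup(update, mode="default", warmup=1)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

for _ in range(10):
    torch.compiler.cudagraph_mark_step_begin()
    td = TensorDict(
        {"obs": torch.randn(32, 8, device="cuda"),
         "target": torch.randn(32, 1, device="cuda")},
        [32],
    )
    out = update(td).clone()  # clone before reuse, as the script does
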

@@ -110,69 +151,48 @@ def main(cfg: "DictConfig"): # noqa: F821
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for i, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
collector_iter = iter(collector)
total_iter = len(collector)

for i in range(total_iter):
timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)

with timeit("collect"):
tensordict = next(collector_iter)

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
pbar.update(current_frames)

with timeit("rb - extend"):
tensordict = tensordict.reshape(-1)
# Add to replay buffer
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
losses = TensorDict(batch_size=[num_updates])
for i in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
with timeit("train"):
if collected_frames >= init_random_frames:
losses = TensorDict(batch_size=[num_updates])
for i in range(num_updates):
with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_td = update(sampled_tensordict).clone()
losses[i] = loss_td.select(
"loss_actor", "loss_qvalue", "loss_alpha"
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
alpha_loss = loss_td["loss_alpha"]

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

losses[i] = loss_td.select(
"loss_actor", "loss_qvalue", "loss_alpha"
).detach()

# Update qnet_target params
target_net_updater.step()

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
@@ -184,46 +204,41 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
metrics_to_log["train/reward"] = episode_rewards
metrics_to_log["train/episode_length"] = episode_length.sum() / len(
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
metrics_to_log["train/alpha"] = loss_td["alpha"].item()
metrics_to_log["train/entropy"] = loss_td["entropy"].item()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
losses = losses.mean()
metrics_to_log["train/q_loss"] = losses.get("loss_qvalue")
metrics_to_log["train/actor_loss"] = losses.get("loss_actor")
metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha")
metrics_to_log["train/alpha"] = loss_td["alpha"]
metrics_to_log["train/entropy"] = loss_td["entropy"]

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
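
Timing in this file moves from manual time.time() bookkeeping to torchrl's timeit sections, which accumulate per-label statistics, can print a periodic report, and can be exported into the logged metrics via timeit.todict. A self-contained sketch of the pattern (the sleep calls stand in for collection and training):

import time
from torchrl._utils import timeit

total_iter = 20
for i in range(total_iter):
    # Periodically print the aggregated timing table (erasing the previous one).
    timeit.printevery(num_prints=10, total_count=total_iter, erase=True)
    with timeit("collect"):
        time.sleep(0.01)  # stand-in for next(collector_iter)
    with timeit("train"):
        time.sleep(0.02)  # stand-in for the update loop
# Timings can also be exported and logged as metrics, as in the script:
metrics = timeit.todict(prefix="time")
print(metrics)
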
21 changes: 10 additions & 11 deletions sota-implementations/sac/utils.py
@@ -105,7 +105,7 @@ def make_environment(cfg, logger=None):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(cfg, train_env, actor_model_explore, compile_mode):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
@@ -120,6 +120,8 @@ def make_collector(cfg, train_env, actor_model_explore):
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=device,
compile_policy={"mode": compile_mode} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
)
collector.set_seed(cfg.env.seed)
return collector
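
make_collector now forwards the compile settings so the inference policy itself can be compiled, and optionally CUDA-graphed, inside the collector. An illustrative standalone construction, assuming gymnasium is installed and using a throwaway linear policy on Pendulum-v1 (not the repo's models):

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")  # 3-dim observation, 1-dim action
policy = TensorDictModule(
    nn.Linear(3, env.action_spec.shape[-1]),
    in_keys=["observation"],
    out_keys=["action"],
)
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=200,
    total_frames=1_000,
    device="cpu",
    compile_policy={"mode": "default"},  # same shape as the kwarg added above
    cudagraph_policy=False,  # only meaningful on a CUDA device
)
for data in collector:
    pass  # each batch holds 200 transitions
collector.shutdown()
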
@@ -169,7 +171,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
"""Make SAC agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec_unbatched
action_spec = train_env.action_spec_unbatched.to(device)
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
@@ -188,7 +190,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
actor_extractor = NormalParamExtractor(
scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}",
scale_lb=cfg.network.scale_lb,
)
).to(device)
actor_net = nn.Sequential(actor_net, actor_extractor)

in_keys_actor = in_keys
@@ -211,22 +213,19 @@ def make_sac_agent(cfg, train_env, eval_env, device):
)

# Define Critic Network
qvalue_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": 1,
"activation_class": get_activation(cfg),
}

qvalue_net = MLP(
**qvalue_net_kwargs,
num_cells=cfg.network.hidden_sizes,
out_features=1,
activation_class=get_activation(cfg),
device=device,
)

qvalue = ValueOperator(
in_keys=["action"] + in_keys,
module=qvalue_net,
)

model = nn.ModuleList([actor, qvalue]).to(device)
model = nn.ModuleList([actor, qvalue])

# init nets
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
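
The remaining utils.py changes construct components directly on the target device (the action spec, the param extractor, the Q-value MLP) instead of assembling everything on CPU and calling .to(device) on the ModuleList afterwards. A small sketch of the same pattern with torchrl's MLP:

import torch
from torch import nn
from torchrl.modules import MLP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Built directly on the target device: no intermediate CPU allocation or copy.
qvalue_net = MLP(
    num_cells=[256, 256],
    out_features=1,
    activation_class=nn.ReLU,
    device=device,
)
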
