[Feature] CROSSQ compatibility with compile
ghstack-source-id: 223f1c7d4ffbd2086655391083875022035da567
Pull Request resolved: #2554
vmoens committed Nov 12, 2024
1 parent 87ce4b3 commit ef9dabc
Showing 4 changed files with 135 additions and 93 deletions.
7 changes: 5 additions & 2 deletions sota-implementations/crossq/config.yaml
@@ -12,7 +12,7 @@ collector:
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
device:
env_per_collector: 1
reset_at_each_iter: False

@@ -46,7 +46,10 @@ network:
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"
device:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
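With these keys in place, compilation can presumably be toggled from the command line via the usual Hydra override syntax, e.g. python sota-implementations/crossq/crossq.py network.compile=True network.compile_mode=reduce-overhead network.cudagraphs=False (illustrative invocation, not part of the commit).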
190 changes: 112 additions & 78 deletions sota-implementations/crossq/crossq.py
@@ -10,15 +10,17 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
@@ -71,8 +73,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create CrossQ loss
loss_module = make_loss_module(cfg, model)

compile_mode = None
if cfg.network.compile:
if cfg.network.compile_mode not in (None, ""):
compile_mode = cfg.network.compile_mode
elif cfg.network.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device)
collector = make_collector(
cfg,
train_env,
exploration_policy.eval(),
device=device,
compile=cfg.network.compile,
compile_mode=compile_mode,
cudagraph=cfg.network.cudagraphs,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
@@ -90,8 +109,58 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_alpha,
) = make_crossQ_optimizer(cfg, loss_module)

def update_qloss(sampled_tensordict):
optimizer_critic.zero_grad(set_to_none=True)
td_loss = {}
q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict)
sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"])
q_loss = q_loss.mean()

# Update critic
q_loss.backward()
optimizer_critic.step()
td_loss["loss_qvalue"] = q_loss
td_loss["loss_actor"] = float("nan")
td_loss["loss_alpha"] = float("nan")
return TensorDict(td_loss).detach()

def update_all(
sampled_tensordict: TensorDict, update_qloss=update_qloss
): # bind update_qloss
# Zero actor/alpha grads before the joint backward below
optimizer_actor.zero_grad(set_to_none=True)
optimizer_alpha.zero_grad(set_to_none=True)

# Compute loss
td_loss = update_qloss(sampled_tensordict)

actor_loss, metadata_actor = loss_module.actor_loss(sampled_tensordict)
actor_loss = actor_loss.mean()
alpha_loss = loss_module.alpha_loss(
log_prob=metadata_actor["log_prob"].detach()
).mean()

# Update actor and alpha (a single backward pass populates both gradient sets)
(actor_loss + alpha_loss).backward()
optimizer_actor.step()
optimizer_alpha.step()

td_loss["loss_actor"] = actor_loss
td_loss["loss_alpha"] = alpha_loss

return TensorDict(td_loss).detach()

if compile_mode:
update_all = torch.compile(update_all, mode=compile_mode)
update_qloss = torch.compile(update_qloss, mode=compile_mode)
if cfg.network.cudagraphs:
update_all = CudaGraphModule(update_all, warmup=50)
update_qloss = CudaGraphModule(update_qloss, warmup=50)

def update(sampled_tensordict: TensorDict, update_actor: bool):
if update_actor:
return update_all(sampled_tensordict)
return update_qloss(sampled_tensordict)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -106,127 +175,92 @@ def main(cfg: "DictConfig"): # noqa: F821
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
update_counter = 0
delayed_updates = cfg.optim.policy_update_delay
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
torch.compiler.cudagraph_mark_step_begin()
tensordict = next(c_iter)

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
pbar.update(current_frames)
tensordict = tensordict.reshape(-1)

with timeit("rb - extend"):
# Add to replay buffer
replay_buffer.extend(tensordict)
collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
alpha_losses,
q_losses,
) = ([], [], [])
tds = []
for _ in range(num_updates):

# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
q_loss = q_loss.mean()
# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.detach().item())

if update_actor:
actor_loss, metadata_actor = loss_module.actor_loss(
sampled_tensordict
)
actor_loss = actor_loss.mean()
alpha_loss = loss_module.alpha_loss(
log_prob=metadata_actor["log_prob"]
).mean()

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

actor_losses.append(actor_loss.detach().item())
alpha_losses.append(alpha_loss.detach().item())

with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
td_loss = update(sampled_tensordict, update_actor=update_actor)
tds.append(td_loss)
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
tds = TensorDict.stack(tds).nanmean()
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
episode_rewards = tensordict["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses).item()
metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item()
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time

# Logging
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/actor_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()
if i % 20 == 0:
timeit.print()
timeit.erase()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
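For reference, a minimal, self-contained sketch (illustrative only, not part of the commit) of the wrapping pattern the script now uses: an update function is compiled with torch.compile and, when CUDA graphs are requested, further wrapped in tensordict.nn.CudaGraphModule, with torch.compiler.cudagraph_mark_step_begin() called at the top of each iteration. The toy network, optimizer and data below are placeholders.

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

device = "cuda" if torch.cuda.is_available() else "cpu"
net = torch.nn.Linear(4, 1, device=device)
optim = torch.optim.Adam(net.parameters())

def update(td: TensorDict) -> TensorDict:
    # One optimization step on a toy regression loss.
    optim.zero_grad(set_to_none=True)
    loss = (net(td["obs"]) - td["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    return TensorDict({"loss": loss.detach()})

update = torch.compile(update, mode="reduce-overhead")  # mirrors the script's fallback mode
use_cudagraphs = False  # CudaGraphModule needs a CUDA device and static shapes
if use_cudagraphs and device == "cuda":
    update = CudaGraphModule(update, warmup=50)

for _ in range(3):
    torch.compiler.cudagraph_mark_step_begin()  # harmless when CUDA graphs are unused
    td = TensorDict({
        "obs": torch.randn(8, 4, device=device),
        "target": torch.randn(8, 1, device=device),
    })
    print(update(td)["loss"].item())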
17 changes: 13 additions & 4 deletions sota-implementations/crossq/utils.py
@@ -90,7 +90,15 @@ def make_environment(cfg):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore, device):
def make_collector(
cfg,
train_env,
actor_model_explore,
device,
compile=False,
compile_mode=None,
cudagraph=False,
):
"""Make collector."""
collector = SyncDataCollector(
train_env,
@@ -99,6 +107,8 @@ def make_collector(cfg, train_env, actor_model_explore, device):
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
collector.set_seed(cfg.env.seed)
return collector
@@ -147,9 +157,7 @@ def make_crossQ_agent(cfg, train_env, device):
"""Make CrossQ agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.single_action_spec
actor_net_kwargs = {
"num_cells": cfg.network.actor_hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
@@ -169,6 +177,7 @@ def make_crossQ_agent(cfg, train_env, device):
"low": action_spec.space.low,
"high": action_spec.space.high,
"tanh_loc": False,
"safe_tanh": not cfg.network.compile,
}

actor_extractor = NormalParamExtractor(
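A usage sketch of the collector options added above (illustrative only; the Pendulum-v1 environment and the throwaway linear policy are placeholders, not part of the commit). compile_policy carries the torch.compile keyword arguments (here just the mode) while cudagraph_policy is a plain flag, matching how make_collector forwards the config.

import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
# Throwaway policy: Pendulum observations have 3 features, actions 1.
policy = TensorDictModule(
    torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=200,
    total_frames=1_000,
    device="cpu",
    compile_policy={"mode": "reduce-overhead"},  # or False to leave the policy uncompiled
    cudagraph_policy=False,  # requires a CUDA device when enabled
)
for batch in collector:
    print(batch["action"].shape)  # torch.Size([200, 1])
collector.shutdown()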
14 changes: 5 additions & 9 deletions torchrl/objectives/crossq.py
@@ -340,6 +340,8 @@ def __init__(
self._action_spec = action_spec
self._make_vmap()
self.reduction = reduction
# init target entropy
_ = self.target_entropy

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
@@ -513,15 +515,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
**metadata_actor,
**value_metadata,
}
td_out = TensorDict(out, [])
# td_out = td_out.named_apply(
# lambda name, value: (
# _reduce(value, reduction=self.reduction)
# if name.startswith("loss_")
# else value
# ),
# batch_size=[],
# )
td_out = TensorDict(out)
return td_out

@property
@@ -543,6 +537,7 @@ def actor_loss(
Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action.
"""
tensordict = tensordict.copy()
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
@@ -584,6 +579,7 @@ def qvalue_loss(
Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
the detached `"td_error"` to be used for prioritized sampling.
"""
tensordict = tensordict.copy()
# # compute next action
with torch.no_grad():
with set_exploration_type(
