[Feature] IQL compatibility with compile

ghstack-source-id: 29e638710dcb9bf2d84196e23df7954be2210053
Pull Request resolved: #2649

vmoens committed Dec 15, 2024
1 parent c57c98e commit 79986a6
Showing 4 changed files with 205 additions and 185 deletions.
169 changes: 95 additions & 74 deletions sota-implementations/iql/discrete_iql.py
@@ -13,16 +13,20 @@
 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
 
+from torchrl._utils import timeit
+
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -87,8 +91,19 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create model
     model = make_discrete_iql_model(cfg, train_env, eval_env, device)
 
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0])
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )
 
     # Create loss
     loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
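
The branch above only picks a torch.compile mode when none is set explicitly: "reduce-overhead" for plain compilation, but "default" when CudaGraphModule will also wrap the update, presumably because "reduce-overhead" performs its own CUDA-graph capture and stacking the two would be redundant. The same decision rule as a standalone sketch (the helper name is illustrative, not part of this commit):

    def pick_compile_mode(compile_enabled, cudagraphs, requested=None):
        # Mirrors the rule above: an explicit request wins; otherwise avoid
        # "reduce-overhead" (which captures CUDA graphs on its own) when the
        # update will additionally be wrapped in CudaGraphModule.
        if not compile_enabled:
            return None
        if requested not in ("", None):
            return requested
        return "default" if cudagraphs else "reduce-overhead"

    # pick_compile_mode(True, cudagraphs=False) -> "reduce-overhead"
    # pick_compile_mode(True, cudagraphs=True)  -> "default"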
@@ -97,6 +112,34 @@ def main(cfg: "DictConfig"): # noqa: F821
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
-
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
+    del optimizer_actor, optimizer_critic, optimizer_value
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+        # compute losses
+        actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
+        value_loss, _ = loss_module.value_loss(sampled_tensordict)
+        q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
+
+        # update qnet_target params
+        target_net_updater.step()
+        return TensorDict(
+            metadata.update(
+                {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
+            )
+        ).detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
     # Main loop
     collected_frames = 0
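
Folding the whole optimization step into a single update function is what makes it a good compile target: group_optimizers merges the actor, critic, and value optimizers so one zero_grad/step pair drives all parameter groups, and the three IQL losses are summed into one backward pass. Summing is equivalent to the previous three separate passes as long as each loss only propagates gradients to its own network, which is how the loss module is set up here. A minimal sketch of the pattern, assuming only that group_optimizers returns an optimizer-like wrapper as used above:

    import torch
    from torch import nn
    from torchrl.objectives import group_optimizers

    actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
    optimizer = group_optimizers(
        torch.optim.Adam(actor.parameters(), lr=3e-4),
        torch.optim.Adam(critic.parameters(), lr=3e-4),
    )

    x = torch.randn(8, 4)
    optimizer.zero_grad(set_to_none=True)
    # one summed backward instead of one backward/step pair per optimizer
    loss = actor(x).pow(2).mean() + critic(x).pow(2).mean()
    loss.backward()
    optimizer.step()  # steps every wrapped optimizer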
@@ -112,103 +155,81 @@ def main(cfg: "DictConfig"): # noqa: F821
     eval_iter = cfg.logger.eval_iter
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.collector.max_frames_per_traj
-    sampling_start = start_time = time.time()
-    for tensordict in collector:
-        sampling_time = time.time() - sampling_start
-        pbar.update(tensordict.numel())
+    collector_iter = iter(collector)
+    for _ in range(len(collector)):
+        with timeit("collection"):
+            tensordict = next(collector_iter)
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
 
         # update weights of the inference policy
         collector.update_policy_weights_()
 
-        tensordict = tensordict.reshape(-1)
-        current_frames = tensordict.numel()
-        # add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("buffer - extend"):
+            tensordict = tensordict.reshape(-1)
+
+            # add to replay buffer
+            replay_buffer.extend(tensordict)
         collected_frames += current_frames
 
         # optimization steps
-        training_start = time.time()
-        if collected_frames >= init_random_frames:
-            for _ in range(num_updates):
-                # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample().clone()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(
-                        device, non_blocking=True
-                    )
-                else:
-                    sampled_tensordict = sampled_tensordict
-                # compute losses
-                actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
-                optimizer_actor.zero_grad()
-                actor_loss.backward()
-                optimizer_actor.step()
-
-                value_loss, _ = loss_module.value_loss(sampled_tensordict)
-                optimizer_value.zero_grad()
-                value_loss.backward()
-                optimizer_value.step()
-
-                q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-
-                # update qnet_target params
-                target_net_updater.step()
-
-                # update priority
-                if prb:
-                    sampled_tensordict.set(
-                        loss_module.tensor_keys.priority,
-                        metadata.pop("td_error").detach().max(0).values,
-                    )
-                    replay_buffer.update_priority(sampled_tensordict)
-
-        training_time = time.time() - training_start
+        with timeit("training"):
+            if collected_frames >= init_random_frames:
+                for _ in range(num_updates):
+                    # sample from replay buffer
+                    with timeit("buffer - sample"):
+                        sampled_tensordict = replay_buffer.sample().to(device)
+
+                    with timeit("training - update"):
+                        metadata = update(sampled_tensordict)
+                    # update priority
+                    if prb:
+                        sampled_tensordict.set(
+                            loss_module.tensor_keys.priority,
+                            metadata.pop("td_error").detach().max(0).values,
+                        )
+                        replay_buffer.update_priority(sampled_tensordict)
 
         episode_rewards = tensordict["next", "episode_reward"][
             tensordict["next", "done"]
         ]
 
         # Logging
         metrics_to_log = {}
-        if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][
-                tensordict["next", "done"]
-            ]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
-                episode_length
-            )
-        if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = q_loss.detach()
-            metrics_to_log["train/actor_loss"] = actor_loss.detach()
-            metrics_to_log["train/value_loss"] = value_loss.detach()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
 
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                )
                 eval_env.apply(dump_video)
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
 
+        # Logging
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][
+                tensordict["next", "done"]
+            ]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = metadata["q_loss"]
+            metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
+            metrics_to_log["train/value_loss"] = metadata["value_loss"]
+        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
+        timeit.erase()
 
     collector.shutdown()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
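
When prb is enabled, the per-sample TD error computed inside update travels back to the buffer through the sampled tensordict, as in the priority-update branch above. A minimal round-trip sketch, using TensorDictPrioritizedReplayBuffer as a plausible stand-in for what make_replay_buffer builds when prioritization is on:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

    rb = TensorDictPrioritizedReplayBuffer(
        alpha=0.7, beta=0.5, storage=LazyTensorStorage(1000), priority_key="td_error"
    )
    rb.extend(TensorDict({"obs": torch.randn(16, 4)}, batch_size=[16]))

    sample = rb.sample(8)  # carries the indices needed for the priority update
    # stand-in for the TD error the loss module writes under tensor_keys.priority
    sample.set("td_error", torch.rand(8))
    rb.update_priority(sample)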
84 changes: 47 additions & 37 deletions sota-implementations/iql/iql_offline.py
@@ -11,16 +11,19 @@
 """
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
 
+from torchrl._utils import timeit
+
 from torchrl.envs import set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
 
 from utils import (
@@ -85,54 +88,62 @@ def main(cfg: "DictConfig"): # noqa: F821
     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
         cfg.optim, loss_module
     )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
 
-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
-
-    gradient_steps = cfg.optim.gradient_steps
-    evaluation_interval = cfg.logger.eval_iter
-    eval_steps = cfg.logger.eval_steps
-
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-
-        if data.device != device:
-            data = data.to(device, non_blocking=True)
-
+    def update(data):
+        optimizer.zero_grad(set_to_none=True)
         # compute losses
         loss_info = loss_module(data)
         actor_loss = loss_info["loss_actor"]
         value_loss = loss_info["loss_value"]
         q_loss = loss_info["loss_qvalue"]
 
-        optimizer_actor.zero_grad()
-        actor_loss.backward()
-        optimizer_actor.step()
-
-        optimizer_value.zero_grad()
-        value_loss.backward()
-        optimizer_value.step()
-
-        optimizer_critic.zero_grad()
-        q_loss.backward()
-        optimizer_critic.step()
+        (actor_loss + value_loss + q_loss).backward()
+        optimizer.step()
 
         # update qnet_target params
         target_net_updater.step()
+        return loss_info.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(range(cfg.optim.gradient_steps))
+
+    evaluation_interval = cfg.logger.eval_iter
+    eval_steps = cfg.logger.eval_steps
+
+    # Training loop
+    for i in pbar:
+        # sample data
+        with timeit("sample"):
+            data = replay_buffer.sample()
+            data = data.to(device)
+
+        with timeit("update"):
+            loss_info = update(data)
 
-        # log metrics
-        to_log = {
-            "loss_actor": actor_loss.item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_value": value_loss.item(),
-        }
-
         # evaluation
+        to_log = loss_info.to_dict()
         if i % evaluation_interval == 0:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_td = eval_env.rollout(
                     max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                )
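
timeit here is torchrl's lightweight profiling helper (torchrl._utils.timeit), not the standard-library module: each with timeit("name") block accumulates wall-clock time under its key, timeit.todict(prefix=...) turns the accumulators into loggable metrics, and timeit.erase() resets them. A small sketch of the pattern both scripts now follow, assuming only the calls visible in this diff (the exact key format is an assumption):

    import time

    from torchrl._utils import timeit

    for step in range(3):
        with timeit("sample"):
            time.sleep(0.01)  # stand-in for replay_buffer.sample()
        with timeit("update"):
            time.sleep(0.02)  # stand-in for the compiled update

    metrics = timeit.todict(prefix="time")
    print(metrics)  # e.g. {"time/sample": ..., "time/update": ...}
    timeit.erase()  # reset accumulators, as done once per iteration above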
@@ -147,7 +158,6 @@ def main(cfg: "DictConfig"): # noqa: F821
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
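
Both scripts gate CudaGraphModule (from tensordict.nn) behind an explicit warning because CUDA-graph replay has real constraints: the wrapped function must see inputs with stable shapes, dtypes, and devices, and the warmup=50 calls leave room for torch.compile to finish recompiling before capture. A toy sketch of the wrapping order used above; the update body is illustrative only and needs a CUDA device to run:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule

    net = torch.nn.Linear(4, 1, device="cuda")
    optim = torch.optim.Adam(net.parameters())

    def update(td: TensorDict) -> TensorDict:
        optim.zero_grad(set_to_none=True)
        loss = net(td["obs"]).pow(2).mean()
        loss.backward()
        optim.step()
        return TensorDict({"loss": loss.detach()})

    update = torch.compile(update, mode="default")  # "default", not "reduce-overhead"
    update = CudaGraphModule(update, warmup=50)  # capture once warmup calls are done

    td = TensorDict({"obs": torch.randn(8, 4, device="cuda")}, batch_size=[8])
    for _ in range(60):  # same shapes and devices on every call
        out = update(td)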