[Feature] IQL compatibility with compile #2649

Merged (38 commits, Dec 16, 2024)
174 changes: 99 additions & 75 deletions sota-implementations/iql/discrete_iql.py
@@ -13,16 +13,20 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
@@ -37,6 +41,9 @@
)


torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="discrete_iql")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()
@@ -87,16 +94,54 @@ def main(cfg: "DictConfig"):  # noqa: F821
# Create model
model = make_discrete_iql_model(cfg, train_env, eval_env, device)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create collector
collector = make_collector(cfg, train_env, actor_model_explore=model[0])
collector = make_collector(
cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
)

# Create loss
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

# Create optimizer
optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
cfg.optim, loss_module
)
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
del optimizer_actor, optimizer_critic, optimizer_value

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)
# compute losses
actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
value_loss, _ = loss_module.value_loss(sampled_tensordict)
q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
(actor_loss + value_loss + q_loss).backward()
optimizer.step()

# update qnet_target params
target_net_updater.step()
metadata.update(
{"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
)
return TensorDict(metadata).detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
collected_frames = 0
@@ -112,103 +157,82 @@ def main(cfg: "DictConfig"):  # noqa: F821
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.collector.max_frames_per_traj
sampling_start = start_time = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
pbar.update(tensordict.numel())

collector_iter = iter(collector)
for _ in range(len(collector)):
with timeit("collection"):
tensordict = next(collector_iter)
current_frames = tensordict.numel()
pbar.update(current_frames)

# update weights of the inference policy
collector.update_policy_weights_()

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("buffer - extend"):
tensordict = tensordict.reshape(-1)

# add to replay buffer
replay_buffer.extend(tensordict)
collected_frames += current_frames

# optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
for _ in range(num_updates):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict
# compute losses
actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

value_loss, _ = loss_module.value_loss(sampled_tensordict)
optimizer_value.zero_grad()
value_loss.backward()
optimizer_value.step()

q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# update qnet_target params
target_net_updater.step()

# update priority
if prb:
sampled_tensordict.set(
loss_module.tensor_keys.priority,
metadata.pop("td_error").detach().max(0).values,
)
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
with timeit("training"):
if collected_frames >= init_random_frames:
for _ in range(num_updates):
# sample from replay buffer
with timeit("buffer - sample"):
sampled_tensordict = replay_buffer.sample().to(device)

with timeit("training - update"):
torch.compiler.cudagraph_mark_step_begin()
metadata = update(sampled_tensordict)
# update priority
if prb:
sampled_tensordict.set(
loss_module.tensor_keys.priority,
metadata.pop("td_error").detach().max(0).values,
)
replay_buffer.update_priority(sampled_tensordict)

episode_rewards = tensordict["next", "episode_reward"][
tensordict["next", "done"]
]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][
tensordict["next", "done"]
]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = q_loss.detach()
metrics_to_log["train/actor_loss"] = actor_loss.detach()
metrics_to_log["train/value_loss"] = value_loss.detach()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time

# Logging
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][
tensordict["next", "done"]
]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = metadata["q_loss"]
metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
metrics_to_log["train/value_loss"] = metadata["value_loss"]
metrics_to_log.update(timeit.todict(prefix="time"))
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()
timeit.erase()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
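Note on the pattern above: the three per-loss optimizer steps are replaced by a single fused update that sums the actor, value, and Q losses, runs one backward pass, and steps a grouped optimizer, optionally wrapped in torch.compile and CudaGraphModule. Below is a minimal, runnable sketch of that pattern, not the PR's code: the modules, losses, and config flags are toy stand-ins; only group_optimizers, torch.compile, and CudaGraphModule(warmup=50) mirror the diff.

import warnings

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl.objectives import group_optimizers

# Toy stand-ins for the IQL actor / Q-value / value networks.
actor = torch.nn.Linear(4, 2)
qvalue = torch.nn.Linear(4, 1)
value = torch.nn.Linear(4, 1)

# One grouped optimizer instead of three independently stepped ones.
optimizer = group_optimizers(
    torch.optim.Adam(actor.parameters(), lr=3e-4),
    torch.optim.Adam(qvalue.parameters(), lr=3e-4),
    torch.optim.Adam(value.parameters(), lr=3e-4),
)


def update(batch):
    optimizer.zero_grad(set_to_none=True)
    obs = batch["observation"]
    # Placeholder losses standing in for loss_module.actor_loss / value_loss / qvalue_loss.
    actor_loss = actor(obs).pow(2).mean()
    value_loss = value(obs).pow(2).mean()
    q_loss = qvalue(obs).pow(2).mean()
    # Single backward over the summed losses, single optimizer step.
    (actor_loss + value_loss + q_loss).backward()
    optimizer.step()
    return TensorDict(
        {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
    ).detach()


compile_update = False  # stands in for cfg.compile.compile
use_cudagraphs = False  # stands in for cfg.compile.cudagraphs
if compile_update:
    update = torch.compile(update, mode="reduce-overhead")
    if use_cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)

batch = TensorDict({"observation": torch.randn(32, 4)}, batch_size=[32])
metadata = update(batch)
print(metadata["q_loss"])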
5 changes: 5 additions & 0 deletions sota-implementations/iql/discrete_iql.yaml
@@ -59,3 +59,8 @@ loss:
# IQL specific hyperparameter
temperature: 100
expectile: 0.8

compile:
compile: False
compile_mode: default
cudagraphs: False
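These defaults feed the mode-resolution logic added to both scripts: when compile is enabled but compile_mode is empty or None, the mode falls back to "default" if cudagraphs is set and "reduce-overhead" otherwise. A small self-contained sketch of that resolution (the dataclass is only a stand-in for the Hydra-loaded config):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class CompileConfig:
    # Mirrors the new `compile:` block in discrete_iql.yaml.
    compile: bool = False
    compile_mode: str | None = "default"
    cudagraphs: bool = False


def resolve_compile_mode(cfg: CompileConfig) -> str | None:
    # Same fallback as in the PR: no compilation -> None; empty mode ->
    # "default" when cudagraphs is requested, "reduce-overhead" otherwise.
    if not cfg.compile:
        return None
    mode = cfg.compile_mode
    if mode in ("", None):
        mode = "default" if cfg.cudagraphs else "reduce-overhead"
    return mode


print(resolve_compile_mode(CompileConfig(compile=True, compile_mode="")))  # reduce-overhead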
90 changes: 52 additions & 38 deletions sota-implementations/iql/iql_offline.py
@@ -11,16 +11,19 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
@@ -34,6 +37,9 @@
)


torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="offline_config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()
@@ -79,60 +85,69 @@ def main(cfg: "DictConfig"):  # noqa: F821
model = make_iql_model(cfg, train_env, eval_env, device)

# Create loss
loss_module, target_net_updater = make_loss(cfg.loss, model)
loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)

# Create optimizer
optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
cfg.optim, loss_module
)
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
for i in range(gradient_steps):
pbar.update(1)
# sample data
data = replay_buffer.sample()

if data.device != device:
data = data.to(device, non_blocking=True)

def update(data):
optimizer.zero_grad(set_to_none=True)
# compute losses
loss_info = loss_module(data)
actor_loss = loss_info["loss_actor"]
value_loss = loss_info["loss_value"]
q_loss = loss_info["loss_qvalue"]

optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

optimizer_value.zero_grad()
value_loss.backward()
optimizer_value.step()

optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
(actor_loss + value_loss + q_loss).backward()
optimizer.step()

# update qnet_target params
target_net_updater.step()
return loss_info.detach()

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(range(cfg.optim.gradient_steps))

evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
for i in pbar:
# sample data
with timeit("sample"):
data = replay_buffer.sample()
data = data.to(device)

# log metrics
to_log = {
"loss_actor": actor_loss.item(),
"loss_qvalue": q_loss.item(),
"loss_value": value_loss.item(),
}
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_info = update(data)

# evaluation
to_log = loss_info.to_dict()
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
@@ -147,7 +162,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
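Both updated scripts also replace manual time.time() bookkeeping with torchrl's timeit sections, aggregated into the logged metrics via timeit.todict(prefix="time") and reset with timeit.erase(). A minimal sketch of that profiling pattern, assuming a torchrl version that exposes timeit as used in the diff; the work inside the blocks is a placeholder:

import torch
from torchrl._utils import timeit

for step in range(3):
    with timeit("sample"):
        data = torch.randn(256, 4)  # stand-in for replay_buffer.sample()

    with timeit("update"):
        loss = data.pow(2).mean()  # stand-in for the compiled update()

    # Flatten the accumulated timings into loggable metrics, then reset.
    metrics = timeit.todict(prefix="time")
    print(metrics)
    timeit.erase()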