[Feature] DT compatibility with compile
ghstack-source-id: 362b6e88bad4397f35036391729e58f4f7e4a25d
Pull Request resolved: #2556
vmoens committed Dec 14, 2024
1 parent 7d7cd95 commit fbfe104
Showing 18 changed files with 237 additions and 148 deletions.
sota-implementations/a2c/utils_atari.py (9 changes: 4 additions & 5 deletions)
@@ -8,7 +8,6 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
from torchrl.data import Composite
from torchrl.data.tensor_specs import CategoricalBox
from torchrl.envs import (
CatFrames,
@@ -94,12 +93,12 @@ def make_ppo_modules_pixels(proof_environment, device):
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.action_spec.space, CategoricalBox):
num_outputs = proof_environment.action_spec.space.n
if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
num_outputs = proof_environment.single_action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.action_spec.shape
num_outputs = proof_environment.single_action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -153,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=Composite(action=proof_environment.action_spec.to(device)),
spec=proof_environment.single_full_action_spec.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
sota-implementations/a2c/utils_mujoco.py (7 changes: 3 additions & 4 deletions)
@@ -9,7 +9,6 @@
import torch.optim

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import Composite
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
@@ -55,7 +54,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.action_spec.shape[-1]
num_outputs = proof_environment.single_action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -83,7 +82,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
proof_environment.action_spec.shape[-1], device=device
proof_environment.single_action_spec.shape[-1], device=device
),
)

@@ -95,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=Composite(action=proof_environment.action_spec.to(device)),
spec=proof_environment.single_full_action_spec.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
sota-implementations/cql/utils.py (2 changes: 1 addition & 1 deletion)
@@ -298,7 +298,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):


def make_cql_modules_state(model_cfg, proof_environment):
action_spec = proof_environment.action_spec
action_spec = proof_environment.single_action_spec

actor_net_kwargs = {
"num_cells": model_cfg.hidden_sizes,
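
Note on the spec change in the three files above (illustrative, not part of the diff): on a TorchRL env, `action_spec` carries the env's batch dimensions, while `single_action_spec` (and `single_full_action_spec` for the composite form) describes a single sub-environment, which is what the policy output head and the `ProbabilisticActor` spec need. A minimal sketch of the difference, assuming torchrl and a gym backend are installed; the shapes shown are for Pendulum-v1 and are only indicative:

from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
print(env.batch_size)                # e.g. torch.Size([2]) - two sub-envs
print(env.action_spec.shape)         # e.g. torch.Size([2, 1]) - batched action spec
print(env.single_action_spec.shape)  # e.g. torch.Size([1]) - per-env spec, used for num_outputs
print(env.single_full_action_spec)   # per-env Composite holding the "action" entry, used as the actor spec
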
sota-implementations/decision_transformer/dt.py (78 changes: 52 additions & 26 deletions)
@@ -6,15 +6,18 @@
This is a self-contained example of an offline Decision Transformer training script.
The helper functions are coded in the utils.py associated with this script.
"""

from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -67,58 +70,77 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create policy model
actor = make_dt_model(cfg)
policy = actor.to(model_device)
actor = make_dt_model(cfg, device=model_device)

# Create loss
loss_module = make_dt_loss(cfg.loss, actor)
loss_module = make_dt_loss(cfg.loss, actor, device=model_device)

# Create optimizer
transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
policy=actor,
inference_context=cfg.env.inference_context,
).to(model_device)
device=model_device,
)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad
eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(1)

# Sample data
data = offline_buffer.sample()
def update(data: TensorDict) -> TensorDict:
transformer_optim.zero_grad(set_to_none=True)
# Compute loss
loss_vals = loss_module(data.to(model_device))
loss_vals = loss_module(data)
transformer_loss = loss_vals["loss"]

transformer_optim.zero_grad()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
transformer_loss.backward()
torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad)
transformer_optim.step()

scheduler.step()
return loss_vals

if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode, dynamic=True)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
pbar = tqdm.tqdm(range(pretrain_gradient_steps))
for i in pbar:
# Sample data
with timeit("rb - sample"):
data = offline_buffer.sample().to(model_device)
with timeit("update"):
loss_vals = update(data)
scheduler.step()
# Log metrics
to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
@@ -129,13 +151,17 @@ def main(cfg: "DictConfig"): # noqa: F821
to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, to_log, i)

pbar.close()
if not test_env.is_closed:
test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
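
The core of the dt.py change is packaging the whole optimizer step into a single `update` callable so it can be wrapped by `torch.compile` and, optionally, by `CudaGraphModule`. Below is a distilled, self-contained sketch of that pattern; the toy linear model, loss module, optimizer and flags are stand-ins for illustration, only the wrapping logic mirrors the script:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule, TensorDictModule

# Toy stand-ins for the script's model, loss module and optimizer.
net = torch.nn.Linear(4, 1)
loss_module = TensorDictModule(
    lambda obs: (net(obs) ** 2).mean(), in_keys=["observation"], out_keys=["loss"]
)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
clip_grad = 1.0

def update(data: TensorDict) -> TensorDict:
    optim.zero_grad(set_to_none=True)
    loss_vals = loss_module(data)          # writes "loss" into the tensordict
    loss_vals["loss"].backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)
    optim.step()
    return loss_vals.detach()              # detach outputs before using them outside the step

# Same gating as in the script: compile first, then optionally capture a CUDA graph.
use_compile, use_cudagraphs = True, False
if use_compile:
    update = torch.compile(update, mode="reduce-overhead", dynamic=True)
if use_cudagraphs:  # requires a CUDA device and static input shapes
    update = CudaGraphModule(update, warmup=50)

data = TensorDict({"observation": torch.randn(32, 4)}, batch_size=[])
print(update(data)["loss"])
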
sota-implementations/decision_transformer/dt_config.yaml (7 changes: 6 additions & 1 deletion)
@@ -55,7 +55,12 @@ optim:
# loss
loss:
  loss_function: "l2"


compile:
  compile: False
  compile_mode:
  cudagraphs: False

# transformer model
transformer:
  n_embd: 128
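
For reference, this is how dt.py above consumes the new `compile` block: an empty `compile_mode` arrives as `None`, and the script then picks a default depending on whether CUDA graphs are requested. Illustrative snippet only; the `OmegaConf.create` literal stands in for Hydra loading this file:

from omegaconf import OmegaConf

# Stand-in for the Hydra-loaded config (normally read from dt_config.yaml).
cfg = OmegaConf.create({"compile": {"compile": True, "compile_mode": None, "cudagraphs": False}})

if cfg.compile.compile:
    compile_mode = cfg.compile.compile_mode
    if compile_mode in ("", None):
        # dt.py falls back to "default" when cudagraphs is on, "reduce-overhead" otherwise
        compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"
    print(f"torch.compile(update, mode={compile_mode!r}, dynamic=True)")
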
sota-implementations/decision_transformer/odt_config.yaml (6 changes: 6 additions & 0 deletions)
@@ -42,6 +42,7 @@ replay_buffer:

# optimizer
optim:
  optimizer: lamb
  device: null
  lr: 1.0e-4
  weight_decay: 5.0e-4
@@ -56,6 +57,11 @@ loss:
  alpha_init: 0.1
  target_entropy: auto

compile:
  compile: False
  compile_mode:
  cudagraphs: False

# transformer model
transformer:
  n_embd: 512
sota-implementations/decision_transformer/online_dt.py (82 changes: 57 additions & 25 deletions)
@@ -9,14 +9,15 @@
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
from torchrl.record import VideoRecorder
@@ -65,8 +66,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create policy model
actor = make_odt_model(cfg)
policy = actor.to(model_device)
policy = make_odt_model(cfg, device=model_device)

# Create loss
loss_module = make_odt_loss(cfg.loss, policy)
@@ -80,13 +80,46 @@ def main(cfg: "DictConfig"): # noqa: F821
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)
device=model_device,
)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

def update(data):
transformer_optim.zero_grad(set_to_none=True)
temperature_optim.zero_grad(set_to_none=True)
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]

(temperature_loss + transformer_loss).backward()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)

transformer_optim.step()
temperature_optim.step()

return loss_vals.detach()

if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
compile_mode = "default"
update = torch.compile(update, mode=compile_mode, dynamic=False)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
if cfg.optim.optimizer == "lamb":
raise ValueError(
"cudagraphs isn't compatible with the Lamb optimizer. Use optim.optimizer=Adam instead."
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
@@ -100,35 +133,29 @@ def main(cfg: "DictConfig"): # noqa: F821
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(1)
# Sample data
data = offline_buffer.sample()
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]

transformer_optim.zero_grad()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
transformer_loss.backward()
transformer_optim.step()
with timeit("sample"):
# Sample data
data = offline_buffer.sample()

temperature_optim.zero_grad()
temperature_loss.backward()
temperature_optim.step()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_vals = update(data.to(model_device))

scheduler.step()

# Log metrics
to_log = {
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
"train/loss_entropy": loss_vals["loss_entropy"].item(),
"train/loss_alpha": loss_vals["loss_alpha"].item(),
"train/alpha": loss_vals["alpha"].item(),
"train/entropy": loss_vals["entropy"].item(),
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
"train/loss_entropy": loss_vals["loss_entropy"],
"train/loss_alpha": loss_vals["loss_alpha"],
"train/alpha": loss_vals["alpha"],
"train/entropy": loss_vals["entropy"],
}

# Evaluation
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(
ExplorationType.DETERMINISTIC
), timeit("eval"):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
@@ -143,6 +170,11 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, to_log, i)

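
Both training loops above now time their phases with torchrl's `timeit` helper and dump the accumulated timings every 200 steps. A minimal standalone sketch of that API, with sleeps standing in for the real sample/update work (the exact key format returned by `todict` may differ):

import time
from torchrl._utils import timeit

for _ in range(3):
    with timeit("sample"):   # same usage as timeit("rb - sample") / timeit("update") above
        time.sleep(0.01)
    with timeit("update"):
        time.sleep(0.02)

print(timeit.todict(prefix="time"))  # e.g. {"time/sample": ..., "time/update": ...}
timeit.print()   # human-readable summary of the accumulated timings
timeit.erase()   # reset the counters
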

1 comment on commit fbfe104

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the alert threshold ratio of 2.

| Benchmark suite | Current: fbfe104 | Previous: 7d7cd95 | Ratio |
| --- | --- | --- | --- |
| benchmarks/test_replaybuffer_benchmark.py::test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 25.66173680770248 iter/sec (stddev: 0.17008845247611376) | 64.76631898263469 iter/sec (stddev: 0.0005560163738160211) | 2.52 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
