Skip to content

Commit

Permalink
[Feature] DT compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 7505f08ff48199054ae2be8b63641889c4468836
Pull Request resolved: #2556
  • Loading branch information
vmoens committed Nov 18, 2024
1 parent b12e06a commit 0000d5e
Show file tree
Hide file tree
Showing 22 changed files with 217 additions and 126 deletions.
13 changes: 6 additions & 7 deletions sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
from torchrl.data import Composite
from torchrl.data.tensor_specs import CategoricalBox
from torchrl.envs import (
CatFrames,
Expand Down Expand Up @@ -92,16 +91,16 @@ def make_ppo_modules_pixels(proof_environment, device):
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.action_spec.space, CategoricalBox):
num_outputs = proof_environment.action_spec.space.n
if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
num_outputs = proof_environment.single_action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.action_spec.shape
num_outputs = proof_environment.single_action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low.to(device),
"high": proof_environment.action_spec.space.high.to(device),
"low": proof_environment.single_action_spec.space.low.to(device),
"high": proof_environment.single_action_spec.space.high.to(device),
}

# Define input keys
Expand Down Expand Up @@ -151,7 +150,7 @@ def make_ppo_modules_pixels(proof_environment, device):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=Composite(action=proof_environment.action_spec.to(device)),
spec=proof_environment.single_full_action_spec.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
11 changes: 5 additions & 6 deletions sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch.optim

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import Composite
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
Expand Down Expand Up @@ -54,11 +53,11 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.action_spec.shape[-1]
num_outputs = proof_environment.single_action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low.to(device),
"high": proof_environment.action_spec.space.high.to(device),
"low": proof_environment.single_action_spec.space.low.to(device),
"high": proof_environment.single_action_spec.space.high.to(device),
"tanh_loc": False,
"safe_tanh": not compile,
}
Expand All @@ -82,7 +81,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
proof_environment.action_spec.shape[-1], device=device
proof_environment.single_action_spec.shape[-1], device=device
),
)

Expand All @@ -94,7 +93,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=Composite(action=proof_environment.action_spec.to(device)),
spec=proof_environment.single_full_action_spec.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):


def make_cql_modules_state(model_cfg, proof_environment):
action_spec = proof_environment.action_spec
action_spec = proof_environment.single_action_spec

actor_net_kwargs = {
"num_cells": model_cfg.hidden_sizes,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def make_collector(
compile=False,
compile_mode=None,
cudagraph=False,
device: torch.device|None=None,
device: torch.device | None = None,
):
"""Make collector."""
collector = SyncDataCollector(
Expand Down
72 changes: 51 additions & 21 deletions sota-implementations/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
This is a self-contained example of an offline Decision Transformer training script.
The helper functions are coded in the utils.py associated with this script.
"""
import time

import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
Expand Down Expand Up @@ -65,20 +68,20 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create policy model
actor = make_dt_model(cfg)
policy = actor.to(model_device)
actor = make_dt_model(cfg, device=model_device)

# Create loss
loss_module = make_dt_loss(cfg.loss, actor)
loss_module = make_dt_loss(cfg.loss, actor, device=model_device)

# Create optimizer
transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
policy=actor,
inference_context=cfg.env.inference_context,
).to(model_device)
device=model_device,
)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
Expand All @@ -89,34 +92,57 @@ def main(cfg: "DictConfig"): # noqa: F821

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
clip_grad = cfg.optim.clip_grad

def update(data: TensorDict) -> TensorDict:
transformer_optim.zero_grad(set_to_none=True)
# Compute loss
loss_vals = loss_module(data)
transformer_loss = loss_vals["loss"]

torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad)
transformer_loss.backward()
transformer_optim.step()

return loss_vals

compile_mode = None
if cfg.loss.compile:
compile_mode = cfg.loss.compile_mode
if compile_mode in ("", None):
if cfg.loss.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

eval_steps = cfg.logger.eval_steps
pretrain_log_interval = cfg.logger.pretrain_log_interval
reward_scaling = cfg.env.reward_scaling

torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(1)

# Sample data
data = offline_buffer.sample()
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss"]

transformer_optim.zero_grad()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
transformer_loss.backward()
transformer_optim.step()

with timeit("rb - sample"):
data = offline_buffer.sample().to(model_device)
with timeit("update"):
loss_vals = update(data)
scheduler.step()

# Log metrics
to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
Expand All @@ -127,13 +153,17 @@ def main(cfg: "DictConfig"): # noqa: F821
to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, to_log, i)

pbar.close()
if not test_env.is_closed:
test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion sota-implementations/decision_transformer/dt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ optim:
# loss
loss:
loss_function: "l2"

compile: False
compile_mode:
cudagraphs: False

# transformer model
transformer:
n_embd: 128
Expand Down
3 changes: 3 additions & 0 deletions sota-implementations/decision_transformer/odt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ optim:
loss:
alpha_init: 0.1
target_entropy: auto
compile: False
compile_mode:
cudagraphs: False

# transformer model
transformer:
Expand Down
81 changes: 56 additions & 25 deletions sota-implementations/decision_transformer/online_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.libs.gym import set_gym_backend

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
from torchrl.record import VideoRecorder
Expand Down Expand Up @@ -63,8 +64,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create policy model
actor = make_odt_model(cfg)
policy = actor.to(model_device)
policy = make_odt_model(cfg, device=model_device)

# Create loss
loss_module = make_odt_loss(cfg.loss, policy)
Expand All @@ -78,13 +78,46 @@ def main(cfg: "DictConfig"): # noqa: F821
inference_policy = DecisionTransformerInferenceWrapper(
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)
device=model_device,
)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

def update(data):
transformer_optim.zero_grad(set_to_none=True)
temperature_optim.zero_grad(set_to_none=True)
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]

(temperature_loss + transformer_loss).backward()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)

transformer_optim.step()
temperature_optim.step()

return loss_vals.detach()

compile_mode = None
if cfg.loss.compile:
compile_mode = cfg.loss.compile_mode
if compile_mode in ("", None):
if cfg.loss.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
Expand All @@ -98,35 +131,28 @@ def main(cfg: "DictConfig"): # noqa: F821
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(1)
# Sample data
data = offline_buffer.sample()
# Compute loss
loss_vals = loss_module(data.to(model_device))
transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
temperature_loss = loss_vals["loss_alpha"]
with timeit("sample"):
# Sample data
data = offline_buffer.sample()

transformer_optim.zero_grad()
torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
transformer_loss.backward()
transformer_optim.step()

temperature_optim.zero_grad()
temperature_loss.backward()
temperature_optim.step()
with timeit("update"):
loss_vals = update(data.to(model_device))

scheduler.step()

# Log metrics
to_log = {
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
"train/loss_entropy": loss_vals["loss_entropy"].item(),
"train/loss_alpha": loss_vals["loss_alpha"].item(),
"train/alpha": loss_vals["alpha"].item(),
"train/entropy": loss_vals["entropy"].item(),
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
"train/loss_entropy": loss_vals["loss_entropy"],
"train/loss_alpha": loss_vals["loss_alpha"],
"train/alpha": loss_vals["alpha"],
"train/entropy": loss_vals["entropy"],
}

# Evaluation
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(
ExplorationType.DETERMINISTIC
), timeit("eval"):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
Expand All @@ -141,6 +167,11 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, to_log, i)

Expand Down
Loading

0 comments on commit 0000d5e

Please sign in to comment.