Skip to content

Commit

Permalink
[Feature] CROSSQ compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: bd27858d0bd8b1c426ce3c65c9ddbf1d4b2b295c
Pull Request resolved: #2554
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent e834327 commit f960fc7
Show file tree
Hide file tree
Showing 12 changed files with 288 additions and 130 deletions.
10 changes: 10 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings

import hydra

import torch
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -141,6 +147,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
adv_module = CudaGraphModule(adv_module)

Expand Down
12 changes: 10 additions & 2 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings

import hydra

import torch
from tensordict.nn import CudaGraphModule
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -133,12 +139,14 @@ def update(batch):
compile_mode = "reduce-overhead"

update = torch.compile(update, mode=compile_mode)
actor = torch.compile(actor, mode=compile_mode)
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)
actor = CudaGraphModule(actor, warmup=10)
adv_module = CudaGraphModule(adv_module)

# Create collector
Expand Down
13 changes: 12 additions & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
"""
import time
import warnings

import hydra
import numpy as np

import torch
import tqdm
from tensordict.nn import CudaGraphModule
Expand All @@ -32,6 +34,8 @@
make_offline_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -77,7 +81,9 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_env.start()

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
loss_module, target_net_updater = make_continuous_loss(
cfg.loss, model, device=device
)

# Create Optimizer
(
Expand Down Expand Up @@ -134,6 +140,10 @@ def update(data, policy_eval_start, iteration):
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
Expand All @@ -154,6 +164,7 @@ def update(data, policy_eval_start, iteration):

with timeit("update"):
# compute loss
torch.compiler.cudagraph_mark_step_begin()
i_device = torch.tensor(i, device=device)
loss, loss_vals = update(
data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
Expand Down
12 changes: 11 additions & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
The helper functions are coded in the utils.py associated with this script.
"""
import warnings

import hydra
import numpy as np
import torch
Expand All @@ -34,6 +36,8 @@
make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -103,7 +107,9 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
loss_module, target_net_updater = make_continuous_loss(
cfg.loss, model, device=device
)

# Create optimizer
(
Expand Down Expand Up @@ -140,6 +146,10 @@ def update(sampled_tensordict):
if compile_mode:
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
Expand Down
11 changes: 10 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
The helper functions are coded in the utils.py associated with this script.
"""
import warnings

import hydra
import numpy as np

import torch
import torch.cuda
import tqdm
Expand All @@ -33,6 +35,8 @@
make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -70,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821
model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)

# Create loss
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

compile_mode = None
if cfg.loss.compile:
Expand Down Expand Up @@ -123,6 +127,10 @@ def update(sampled_tensordict):
if compile_mode:
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
Expand Down Expand Up @@ -170,6 +178,7 @@ def update(sampled_tensordict):
sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_dict = update(sampled_tensordict)
tds.append(loss_dict)

Expand Down
12 changes: 6 additions & 6 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
spec=action_spec,
distribution_class=TanhNormal,
distribution_kwargs={
"low": action_spec.space.low,
"high": action_spec.space.high,
"low": torch.as_tensor(action_spec.space.low, device=device),
"high": torch.as_tensor(action_spec.space.high, device=device),
"tanh_loc": False,
"safe_tanh": not cfg.loss.compile,
},
Expand Down Expand Up @@ -315,7 +315,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
# ---------


def make_continuous_loss(loss_cfg, model):
def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
loss_module = CQLLoss(
model[0],
model[1],
Expand All @@ -328,19 +328,19 @@ def make_continuous_loss(loss_cfg, model):
with_lagrange=loss_cfg.with_lagrange,
lagrange_thresh=loss_cfg.lagrange_thresh,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

return loss_module, target_net_updater


def make_discrete_loss(loss_cfg, model):
def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
loss_module = DiscreteCQLLoss(
model,
loss_function=loss_cfg.loss_function,
delay_value=True,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

return loss_module, target_net_updater
Expand Down
7 changes: 5 additions & 2 deletions sota-implementations/crossq/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ collector:
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
device:
env_per_collector: 1
reset_at_each_iter: False

Expand Down Expand Up @@ -46,7 +46,10 @@ network:
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"
device:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
Loading

0 comments on commit f960fc7

Please sign in to comment.