[Feature] CROSSQ compatibility with compile
ghstack-source-id: 5f9e72fe8bb64a2c55647b9927ce6b35d2634c04
Pull Request resolved: #2554
vmoens committed Nov 12, 2024
1 parent 3e387ab commit b9efe65
Showing 12 changed files with 237 additions and 128 deletions.
2 changes: 2 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
@@ -7,6 +7,8 @@
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
 
+import torch
+torch.set_float32_matmul_precision('high')
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"): # noqa: F821
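Both a2c scripts (and the other training scripts touched below) gain the same two-line preamble. torch.set_float32_matmul_precision('high') lets PyTorch use TF32-style tensor-core kernels for float32 matmuls, which typically speeds up compiled training on recent NVIDIA GPUs at a small precision cost. A minimal, self-contained sketch of what the setting affects (the tensor sizes are illustrative, not from the commit):

import torch

# 'highest' keeps full-precision float32 matmuls; 'high' allows TF32 (or
# comparable reduced-precision) kernels, trading a little accuracy for speed.
torch.set_float32_matmul_precision('high')

if torch.cuda.is_available():
    a = torch.randn(2048, 2048, device="cuda")
    b = torch.randn(2048, 2048, device="cuda")
    c = a @ b  # now eligible for TF32 tensor-core kernels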
2 changes: 2 additions & 0 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -7,6 +7,8 @@
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
 
+import torch
+torch.set_float32_matmul_precision('high')
 
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
 def main(cfg: "DictConfig"): # noqa: F821
5 changes: 4 additions & 1 deletion sota-implementations/cql/cql_offline.py
@@ -32,6 +32,8 @@
 make_offline_replay_buffer,
 )
 
+import torch
+torch.set_float32_matmul_precision('high')
 
 @hydra.main(config_path="", config_name="offline_config", version_base="1.1")
 def main(cfg: "DictConfig"): # noqa: F821
@@ -77,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 eval_env.start()
 
 # Create loss
-loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
+loss_module, target_net_updater = make_continuous_loss(cfg.loss, model, device=device)
 
 # Create Optimizer
 (
@@ -154,6 +156,7 @@ def update(data, policy_eval_start, iteration):
 
 with timeit("update"):
 # compute loss
+torch.compiler.cudagraph_mark_step_begin()
 i_device = torch.tensor(i, device=device)
 loss, loss_vals = update(
 data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
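The torch.compiler.cudagraph_mark_step_begin() call inserted before each update tells the CUDA-graph machinery that a new training iteration begins, so tensors produced by the previous graph replay are no longer treated as live. A hedged, self-contained sketch of the pattern (the toy model and the "reduce-overhead" mode are illustrative; the actual scripts build their update function from the loss module and config):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
params = torch.nn.Parameter(torch.randn(8, device=device))
optim = torch.optim.Adam([params], lr=1e-3)

def update(batch):
    optim.zero_grad(set_to_none=True)
    loss = (batch @ params).pow(2).mean()
    loss.backward()
    optim.step()
    return loss.detach()

# "reduce-overhead" turns on CUDA graphs; every iteration must then be marked.
update = torch.compile(update, mode="reduce-overhead")

for _ in range(3):
    batch = torch.randn(32, 8, device=device)
    torch.compiler.cudagraph_mark_step_begin()  # new cudagraph step
    loss = update(batch)
    loss = loss.clone()  # cudagraph output buffers are reused on the next step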
4 changes: 3 additions & 1 deletion sota-implementations/cql/cql_online.py
@@ -33,6 +33,8 @@
 make_environment,
 make_replay_buffer,
 )
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(version_base="1.1", config_path="", config_name="online_config")
@@ -103,7 +105,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 )
 
 # Create loss
-loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
+loss_module, target_net_updater = make_continuous_loss(cfg.loss, model, device=device)
 
 # Create optimizer
 (
5 changes: 4 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
@@ -33,6 +33,8 @@
 make_replay_buffer,
 )
 
+import torch
+torch.set_float32_matmul_precision('high')
 
 @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
 def main(cfg: "DictConfig"): # noqa: F821
@@ -70,7 +72,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)
 
 # Create loss
-loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
+loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)
 
 compile_mode = None
 if cfg.loss.compile:
@@ -170,6 +172,7 @@ def update(sampled_tensordict):
 sampled_tensordict = replay_buffer.sample()
 sampled_tensordict = sampled_tensordict.to(device)
 with timeit("update"):
+torch.compiler.cudagraph_mark_step_begin()
 loss_dict = update(sampled_tensordict)
 tds.append(loss_dict)
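As the context lines above show, each script derives a compile_mode from the config and wraps its update step with it; cudagraph_mark_step_begin() then brackets every call. A hedged sketch of how such a selection might look (only cfg.loss.compile is visible in the diff; the compile_mode and cudagraphs fields and the fallback logic are assumptions for illustration):

import torch

# Illustrative stand-ins for Hydra config fields.
cfg_compile = True          # cfg.loss.compile (seen in the diff)
cfg_compile_mode = None     # assumed field, may be left empty
cfg_cudagraphs = False      # assumed field

def update(batch):
    return batch.sum()

compile_mode = None
if cfg_compile:
    compile_mode = cfg_compile_mode
    if compile_mode in (None, ""):
        # pick a cudagraph-friendly mode when cudagraphs are requested
        compile_mode = "reduce-overhead" if cfg_cudagraphs else "default"
    update = torch.compile(update, mode=compile_mode)

print(update(torch.ones(4)))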
12 changes: 6 additions & 6 deletions sota-implementations/cql/utils.py
@@ -217,8 +217,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
 spec=action_spec,
 distribution_class=TanhNormal,
 distribution_kwargs={
-"low": action_spec.space.low,
-"high": action_spec.space.high,
+"low": torch.as_tensor(action_spec.space.low, device=device),
+"high": torch.as_tensor(action_spec.space.high, device=device),
 "tanh_loc": False,
 "safe_tanh": not cfg.loss.compile,
 },
@@ -315,7 +315,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
 # ---------
 
 
-def make_continuous_loss(loss_cfg, model):
+def make_continuous_loss(loss_cfg, model, device: torch.device|None=None):
 loss_module = CQLLoss(
 model[0],
 model[1],
@@ -328,19 +328,19 @@ def make_continuous_loss(loss_cfg, model):
 with_lagrange=loss_cfg.with_lagrange,
 lagrange_thresh=loss_cfg.lagrange_thresh,
 )
-loss_module.make_value_estimator(gamma=loss_cfg.gamma)
+loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
 target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
 
 return loss_module, target_net_updater
 
 
-def make_discrete_loss(loss_cfg, model):
+def make_discrete_loss(loss_cfg, model, device: torch.device|None=None):
 loss_module = DiscreteCQLLoss(
 model,
 loss_function=loss_cfg.loss_function,
 delay_value=True,
 )
-loss_module.make_value_estimator(gamma=loss_cfg.gamma)
+loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
 target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
 
 return loss_module, target_net_updater
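Two device-related fixes in utils.py: the TanhNormal bounds taken from the action spec are materialised as tensors on the training device, so the compiled policy graph does not trigger per-call host-to-device copies, and the same device is forwarded to make_value_estimator so its internal buffers are created in the right place. A hedged, generic sketch of the first idea (the bound values and the squashing function are illustrative, not TanhNormal's actual implementation):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Spec bounds often arrive as numpy arrays or CPU tensors; converting them
# once, up front, keeps the hot path free of implicit device transfers.
low = torch.as_tensor([-1.0, -1.0], device=device)
high = torch.as_tensor([1.0, 1.0], device=device)

def squash(x):
    # map an unbounded network output into [low, high] on a single device
    return low + (high - low) * torch.sigmoid(x)

action = squash(torch.randn(2, device=device))
print(action)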
7 changes: 5 additions & 2 deletions sota-implementations/crossq/config.yaml
@@ -12,7 +12,7 @@ collector:
 init_random_frames: 25000
 frames_per_batch: 1000
 init_env_steps: 1000
-device: cpu
+device:
 env_per_collector: 1
 reset_at_each_iter: False
 
@@ -46,7 +46,10 @@ network:
 actor_activation: relu
 default_policy_scale: 1.0
 scale_lb: 0.1
-device: "cuda:0"
+device:
+compile: False
+compile_mode:
+cudagraphs: False
 
 # logging
 logger:
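Leaving device: blank in the YAML (for both the collector and the network) lets the script choose a device at runtime instead of hard-coding cpu or cuda:0, and the new compile, compile_mode and cudagraphs keys surface the options exercised above. A hedged sketch of how an empty device entry might be resolved (the helper name is illustrative, not from the commit):

import torch

def resolve_device(cfg_device=None):
    # An empty/None entry in the config means "pick automatically".
    if cfg_device:
        return torch.device(cfg_device)
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(resolve_device(None))   # device: left blank -> auto-selected
print(resolve_device("cpu"))  # explicit value still honoured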