From 0000d5e75f08ebdd6a59ee8665b2597948ffbb39 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 09:42:57 +0000 Subject: [PATCH] [Feature] DT compatibility with compile ghstack-source-id: 7505f08ff48199054ae2be8b63641889c4468836 Pull Request resolved: https://github.com/pytorch/rl/pull/2556 --- sota-implementations/a2c/utils_atari.py | 13 ++- sota-implementations/a2c/utils_mujoco.py | 11 ++- sota-implementations/cql/utils.py | 2 +- sota-implementations/ddpg/utils.py | 2 +- .../decision_transformer/dt.py | 72 ++++++++++++----- .../decision_transformer/dt_config.yaml | 5 +- .../decision_transformer/odt_config.yaml | 3 + .../decision_transformer/online_dt.py | 81 +++++++++++++------ .../decision_transformer/utils.py | 47 +++++++---- sota-implementations/dreamer/dreamer_utils.py | 16 ++-- sota-implementations/gail/ppo_utils.py | 11 ++- sota-implementations/impala/utils.py | 5 +- sota-implementations/iql/utils.py | 2 +- sota-implementations/ppo/utils_atari.py | 13 ++- sota-implementations/ppo/utils_mujoco.py | 11 ++- sota-implementations/redq/utils.py | 2 +- torchrl/collectors/collectors.py | 3 + torchrl/data/tensor_specs.py | 6 +- torchrl/modules/distributions/continuous.py | 21 +++-- .../modules/models/decision_transformer.py | 5 ++ torchrl/modules/tensordict_module/actors.py | 4 + torchrl/objectives/decision_transformer.py | 8 +- 22 files changed, 217 insertions(+), 126 deletions(-) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index bf7e23cd8f9..8b5c1e0bbda 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -7,7 +7,6 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import Composite from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, @@ -92,16 +91,16 @@ def make_ppo_modules_pixels(proof_environment, device): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, CategoricalBox): - num_outputs = proof_environment.action_spec.space.n + if isinstance(proof_environment.single_action_spec.space, CategoricalBox): + num_outputs = proof_environment.single_action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape + num_outputs = proof_environment.single_action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low.to(device), - "high": proof_environment.action_spec.space.high.to(device), + "low": proof_environment.single_action_spec.space.low.to(device), + "high": proof_environment.single_action_spec.space.high.to(device), } # Define input keys @@ -151,7 +150,7 @@ def make_ppo_modules_pixels(proof_environment, device): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec.to(device)), + spec=proof_environment.single_full_action_spec.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index e16bcefc890..ad094e75f52 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -8,7 +8,6 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, 
TensorDictModule -from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -54,11 +53,11 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.action_spec.shape[-1] + num_outputs = proof_environment.single_action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low.to(device), - "high": proof_environment.action_spec.space.high.to(device), + "low": proof_environment.single_action_spec.space.low.to(device), + "high": proof_environment.single_action_spec.space.high.to(device), "tanh_loc": False, "safe_tanh": not compile, } @@ -82,7 +81,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec.shape[-1], device=device + proof_environment.single_action_spec.shape[-1], device=device ), ) @@ -94,7 +93,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec.to(device)), + spec=proof_environment.single_full_action_spec.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 0cedfdb07a9..f0ccfb6f0c1 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -285,7 +285,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"): def make_cql_modules_state(model_cfg, proof_environment): - action_spec = proof_environment.action_spec + action_spec = proof_environment.single_action_spec actor_net_kwargs = { "num_cells": model_cfg.hidden_sizes, diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index b94dc64ecfb..7bf440c5152 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -118,7 +118,7 @@ def make_collector( compile=False, compile_mode=None, cudagraph=False, - device: torch.device|None=None, + device: torch.device | None = None, ): """Make collector.""" collector = SyncDataCollector( diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index b892462339c..1b3a069df0f 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -6,13 +6,16 @@ This is a self-contained example of an offline Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. 
""" -import time + +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -65,20 +68,20 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create policy model - actor = make_dt_model(cfg) - policy = actor.to(model_device) + actor = make_dt_model(cfg, device=model_device) # Create loss - loss_module = make_dt_loss(cfg.loss, actor) + loss_module = make_dt_loss(cfg.loss, actor, device=model_device) # Create optimizer transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module) # Create inference policy inference_policy = DecisionTransformerInferenceWrapper( - policy=policy, + policy=actor, inference_context=cfg.env.inference_context, - ).to(model_device) + device=model_device, + ) inference_policy.set_tensor_keys( observation="observation_cat", action="action_cat", @@ -89,34 +92,57 @@ def main(cfg: "DictConfig"): # noqa: F821 pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps clip_grad = cfg.optim.clip_grad + + def update(data: TensorDict) -> TensorDict: + transformer_optim.zero_grad(set_to_none=True) + # Compute loss + loss_vals = loss_module(data) + transformer_loss = loss_vals["loss"] + + torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad) + transformer_loss.backward() + transformer_optim.step() + + return loss_vals + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + eval_steps = cfg.logger.eval_steps pretrain_log_interval = cfg.logger.pretrain_log_interval reward_scaling = cfg.env.reward_scaling torchrl_logger.info(" ***Pretraining*** ") # Pretraining - start_time = time.time() for i in range(pretrain_gradient_steps): pbar.update(1) # Sample data - data = offline_buffer.sample() - # Compute loss - loss_vals = loss_module(data.to(model_device)) - transformer_loss = loss_vals["loss"] - - transformer_optim.zero_grad() - torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) - transformer_loss.backward() - transformer_optim.step() - + with timeit("rb - sample"): + data = offline_buffer.sample().to(model_device) + with timeit("update"): + loss_vals = update(data) scheduler.step() - # Log metrics to_log = {"train/loss": loss_vals["loss"]} # Evaluation - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): if i % pretrain_log_interval == 0: eval_td = test_env.rollout( max_steps=eval_steps, @@ -127,13 +153,17 @@ def main(cfg: "DictConfig"): # noqa: F821 to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) + if i % 200 == 0: + to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger is not None: log_metrics(logger, to_log, i) pbar.close() if not test_env.is_closed: test_env.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/decision_transformer/dt_config.yaml b/sota-implementations/decision_transformer/dt_config.yaml index 4805785a62c..3fb2529a7b7 100644 --- a/sota-implementations/decision_transformer/dt_config.yaml +++ b/sota-implementations/decision_transformer/dt_config.yaml @@ -55,7 +55,10 @@ optim: # loss loss: loss_function: "l2" - + compile: False + compile_mode: + cudagraphs: False + # transformer model transformer: n_embd: 128 diff --git a/sota-implementations/decision_transformer/odt_config.yaml b/sota-implementations/decision_transformer/odt_config.yaml index eec2b455fb3..479b4768859 100644 --- a/sota-implementations/decision_transformer/odt_config.yaml +++ b/sota-implementations/decision_transformer/odt_config.yaml @@ -55,6 +55,9 @@ optim: loss: alpha_init: 0.1 target_entropy: auto + compile: False + compile_mode: + cudagraphs: False # transformer model transformer: diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 184c850b626..0d03493207e 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -7,14 +7,15 @@ The helper functions are coded in the utils.py associated with this script. 
""" import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.libs.gym import set_gym_backend - from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from torchrl.record import VideoRecorder @@ -63,8 +64,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create policy model - actor = make_odt_model(cfg) - policy = actor.to(model_device) + policy = make_odt_model(cfg, device=model_device) # Create loss loss_module = make_odt_loss(cfg.loss, policy) @@ -78,13 +78,46 @@ def main(cfg: "DictConfig"): # noqa: F821 inference_policy = DecisionTransformerInferenceWrapper( policy=policy, inference_context=cfg.env.inference_context, - ).to(model_device) + device=model_device, + ) inference_policy.set_tensor_keys( observation="observation_cat", action="action_cat", return_to_go="return_to_go_cat", ) + def update(data): + transformer_optim.zero_grad(set_to_none=True) + temperature_optim.zero_grad(set_to_none=True) + # Compute loss + loss_vals = loss_module(data.to(model_device)) + transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"] + temperature_loss = loss_vals["loss_alpha"] + + (temperature_loss + transformer_loss).backward() + torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) + + transformer_optim.step() + temperature_optim.step() + + return loss_vals.detach() + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps @@ -98,35 +131,28 @@ def main(cfg: "DictConfig"): # noqa: F821 start_time = time.time() for i in range(pretrain_gradient_steps): pbar.update(1) - # Sample data - data = offline_buffer.sample() - # Compute loss - loss_vals = loss_module(data.to(model_device)) - transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"] - temperature_loss = loss_vals["loss_alpha"] + with timeit("sample"): + # Sample data + data = offline_buffer.sample() - transformer_optim.zero_grad() - torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad) - transformer_loss.backward() - transformer_optim.step() - - temperature_optim.zero_grad() - temperature_loss.backward() - temperature_optim.step() + with timeit("update"): + loss_vals = update(data.to(model_device)) scheduler.step() # Log metrics to_log = { - "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(), - "train/loss_entropy": loss_vals["loss_entropy"].item(), - "train/loss_alpha": loss_vals["loss_alpha"].item(), - "train/alpha": loss_vals["alpha"].item(), - "train/entropy": loss_vals["entropy"].item(), + "train/loss_log_likelihood": loss_vals["loss_log_likelihood"], + "train/loss_entropy": loss_vals["loss_entropy"], + "train/loss_alpha": loss_vals["loss_alpha"], + "train/alpha": loss_vals["alpha"], + "train/entropy": loss_vals["entropy"], } # Evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( @@ -141,6 +167,11 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) + if i % 200 == 0: + to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger is not None: log_metrics(logger, to_log, i) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index ee2cc6e424c..721c3030a51 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import os +from pathlib import Path import torch.nn @@ -155,6 +157,7 @@ def make_env(): obs_std, train, ) + env.start() return env @@ -261,6 +264,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): direct_download=True, prefetch=4, writer=RoundRobinWriter(), + root=Path(os.environ["HOME"]) / ".cache" / "torchrl" / "data" / "d4rl", ) # since we're not extending the data, adding keys can only be done via @@ -334,14 +338,14 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): # ----- -def make_odt_model(cfg): +def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: env_cfg = cfg.env proof_environment = make_transformed_env( make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 ) - action_spec = proof_environment.action_spec - for key, value in proof_environment.observation_spec.items(): + action_spec = proof_environment.single_action_spec + for key, value in proof_environment.single_observation_spec.items(): if key == "observation": state_dim = value.shape[-1] in_keys = [ @@ -354,6 +358,7 @@ def make_odt_model(cfg): state_dim=state_dim, action_dim=action_spec.shape[-1], transformer_config=cfg.transformer, + device=device, ) actor_module = TensorDictModule( @@ -365,7 +370,13 @@ def make_odt_model(cfg): ], ) dist_class = TanhNormal - dist_kwargs = {"low": -1.0, "high": 1.0, "tanh_loc": False, "upscale": 5.0} + dist_kwargs = { + "low": -1.0, + "high": 1.0, + "tanh_loc": False, + "upscale": 5.0, + "safe_tanh": not cfg.loss.compile, + } actor = ProbabilisticActor( spec=action_spec, @@ -387,16 +398,14 @@ def make_odt_model(cfg): return actor -def make_dt_model(cfg): +def make_dt_model(cfg, device: torch.device | None = None): env_cfg = cfg.env proof_environment = make_transformed_env( make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 ) - action_spec = proof_environment.action_spec - for key, value in proof_environment.observation_spec.items(): - if key == "observation": - state_dim = value.shape[-1] + action_spec = proof_environment.single_action_spec + obs_spec = proof_environment.single_observation_spec in_keys = [ "observation_cat", "action_cat", @@ -404,9 +413,10 @@ def make_dt_model(cfg): ] actor_net = DTActor( - state_dim=state_dim, + state_dim=obs_spec["observation"].shape[-1], action_dim=action_spec.shape[-1], transformer_config=cfg.transformer, + device=device, ) actor_module = TensorDictModule( @@ -418,10 +428,11 @@ def make_dt_model(cfg): dist_kwargs = { "low": action_spec.space.low, "high": action_spec.space.high, + "safe": not cfg.loss.compile, } actor = ProbabilisticActor( - spec=action_spec, + spec=action_spec.to(device), in_keys=["param"], out_keys=["action"], module=actor_module, @@ -433,9 +444,10 @@ def make_dt_model(cfg): # init the lazy layers with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = proof_environment.rollout(max_steps=100) + td = proof_environment.fake_tensordict() + td = td.expand((100, *td.shape)) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor @@ -455,10 +467,11 @@ def make_odt_loss(loss_cfg, actor_network): return loss -def make_dt_loss(loss_cfg, actor_network): +def make_dt_loss(loss_cfg, actor_network, device: torch.device | None = None): loss = DTLoss( actor_network, loss_function=loss_cfg.loss_function, + device=device, ) loss.set_keys(action_target="action_cat") return loss @@ -467,7 +480,7 @@ def make_dt_loss(loss_cfg, actor_network): def make_odt_optimizer(optim_cfg, loss_module): dt_optimizer = Lamb( 
loss_module.actor_network_params.flatten_keys().values(), - lr=optim_cfg.lr, + lr=torch.as_tensor(optim_cfg.lr, device=next(loss_module.parameters()).device), weight_decay=optim_cfg.weight_decay, eps=1.0e-8, ) @@ -477,7 +490,7 @@ def make_odt_optimizer(optim_cfg, loss_module): log_temp_optimizer = torch.optim.Adam( [loss_module.log_alpha], - lr=1e-4, + lr=torch.as_tensor(1e-4, device=next(loss_module.parameters()).device), betas=[0.9, 0.999], ) @@ -487,7 +500,7 @@ def make_odt_optimizer(optim_cfg, loss_module): def make_dt_optimizer(optim_cfg, loss_module): dt_optimizer = torch.optim.Adam( loss_module.actor_network_params.flatten_keys().values(), - lr=optim_cfg.lr, + lr=torch.as_tensor(optim_cfg.lr), weight_decay=optim_cfg.weight_decay, eps=1.0e-8, ) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 849d8c813b6..ceecc7956e9 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -472,12 +472,12 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): spec=Composite( **{ "loc": Unbounded( - proof_environment.action_spec.shape, - device=proof_environment.action_spec.device, + proof_environment.single_action_spec.shape, + device=proof_environment.single_action_spec.device, ), "scale": Unbounded( - proof_environment.action_spec.shape, - device=proof_environment.action_spec.device, + proof_environment.single_action_spec.shape, + device=proof_environment.single_action_spec.device, ), } ), @@ -488,7 +488,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=Composite(**{action_key: proof_environment.action_spec}), + spec=Composite(**{action_key: proof_environment.single_action_spec}), ), ) return actor_simulator @@ -529,10 +529,10 @@ def _dreamer_make_actor_real( spec=Composite( **{ "loc": Unbounded( - proof_environment.action_spec.shape, + proof_environment.single_action_spec.shape, ), "scale": Unbounded( - proof_environment.action_spec.shape, + proof_environment.single_action_spec.shape, ), } ), @@ -543,7 +543,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}), + spec=proof_environment.single_full_action_spec.to("cpu"), ), ), SafeModule( diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 7986738f8e6..635c24517e6 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -7,7 +7,6 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import CompositeSpec from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -49,11 +48,11 @@ def make_ppo_models_state(proof_environment): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.action_spec.shape[-1] + num_outputs = proof_environment.single_action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.single_action_spec.space.low, + "high": proof_environment.single_action_spec.space.high, 
"tanh_loc": False, } @@ -75,7 +74,7 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec.shape[-1], scale_lb=1e-8 + proof_environment.single_action_spec.shape[-1], scale_lb=1e-8 ), ) @@ -87,7 +86,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 30293940377..7ed16313176 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -6,7 +6,6 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import Composite from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -69,7 +68,7 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - num_outputs = proof_environment.action_spec.space.n + num_outputs = proof_environment.single_action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} @@ -117,7 +116,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index a24c6168375..277b83a4ef3 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -249,7 +249,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): def make_iql_modules_state(model_cfg, proof_environment): - action_spec = proof_environment.action_spec + action_spec = proof_environment.single_action_spec actor_net_kwargs = { "num_cells": model_cfg.hidden_sizes, diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 50f91ed49cd..1199c7dd368 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -6,7 +6,6 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import Composite from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, @@ -92,16 +91,16 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, CategoricalBox): - num_outputs = proof_environment.action_spec.space.n + if isinstance(proof_environment.single_action_spec.space, CategoricalBox): + num_outputs = proof_environment.single_action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape + num_outputs = proof_environment.single_action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.single_action_spec.space.low, + "high": 
proof_environment.single_action_spec.space.high, } # Define input keys @@ -148,7 +147,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index a05d205b000..635c24517e6 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -7,7 +7,6 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -49,11 +48,11 @@ def make_ppo_models_state(proof_environment): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.action_spec.shape[-1] + num_outputs = proof_environment.single_action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.single_action_spec.space.low, + "high": proof_environment.single_action_spec.space.high, "tanh_loc": False, } @@ -75,7 +74,7 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec.shape[-1], scale_lb=1e-8 + proof_environment.single_action_spec.shape[-1], scale_lb=1e-8 ), ) @@ -87,7 +86,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=proof_environment.single_full_action_spec, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 8312d359366..a35d456a2db 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -406,7 +406,7 @@ def make_redq_model( default_policy_scale = cfg.network.default_policy_scale gSDE = cfg.exploration.gSDE - action_spec = proof_environment.action_spec + action_spec = proof_environment.single_action_spec if actor_net_kwargs is None: actor_net_kwargs = {} diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index bf1192fc490..d805307bb3f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -66,13 +66,16 @@ RandomPolicy, set_exploration_type, ) + try: from torch.compiler import cudagraph_mark_step_begin except ImportError: + def cudagraph_mark_step_begin(): """Placeholder when cudagraph_mark_step_begin is missing.""" ... 
+ _TIMEOUT = 1.0 INSTANTIATE_TIMEOUT = 20 _MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 860de4b6793..3e3fb9daee7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2315,10 +2315,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - self.space.device = dest_device + space = self.space.to(dest_device) return Bounded( - low=self.space.low, - high=self.space.high, + low=space.low, + high=space.high, shape=self.shape, device=dest_device, dtype=dest_dtype, diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 32862ffe1c3..fa7e8ee7ca1 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -393,7 +393,6 @@ def __init__( event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True, - **kwargs, ): if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) @@ -679,6 +678,7 @@ def __init__( event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, + safe: bool = True, ): minmax_msg = "high value has been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): @@ -691,12 +691,19 @@ def __init__( if not all(high > low): raise ValueError(minmax_msg) - t = SafeTanhTransform() - non_trivial_min = (isinstance(low, torch.Tensor) and (low != -1.0).any()) or ( - not isinstance(low, torch.Tensor) and low != -1.0 + if safe: + if is_dynamo_compiling(): + _err_compile_safetanh() + t = SafeTanhTransform() + else: + t = torch.distributions.TanhTransform() + non_trivial_min = is_dynamo_compiling or ( + (isinstance(low, torch.Tensor) and (low != -1.0).any()) + or (not isinstance(low, torch.Tensor) and low != -1.0) ) - non_trivial_max = (isinstance(high, torch.Tensor) and (high != 1.0).any()) or ( - not isinstance(high, torch.Tensor) and high != 1.0 + non_trivial_max = is_dynamo_compiling or ( + (isinstance(high, torch.Tensor) and (high != 1.0).any()) + or (not isinstance(high, torch.Tensor) and high != 1.0) ) self.non_trivial = non_trivial_min or non_trivial_max @@ -773,7 +780,7 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: def _err_compile_safetanh(): raise RuntimeError( - "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass" + "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass " "safe_tanh=False. " "If you are using a ProbabilisticTensorDictModule, this can be done via " "`distribution_kwargs={'safe_tanh': False}`. " diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 8eb72f1f9ea..8a20ad2eba8 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -90,7 +90,12 @@ def __init__( state_dim, action_dim, config: dict | DTConfig = None, + device: torch.device | None = None, ): + if device is not None: + with torch.device(device): + return self.__init__(state_dim, action_dim, config) + if not _has_transformers: raise ImportError( "transformers is not installed. Please install it with `pip install transformers`." 
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index bf945e609cb..60d82556829 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1783,6 +1783,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries of the context will be masked. Defaults to 5. spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. + device (torch.device, optional): if provided, the device where the buffers / specs will be placed. Examples: >>> import torch @@ -1836,6 +1837,7 @@ def __init__( *, inference_context: int = 5, spec: Optional[TensorSpec] = None, + device: torch.device | None = None, ): super().__init__(policy) self.observation_key = "observation" @@ -1857,6 +1859,8 @@ def __init__( self._spec[self.action_key] = None else: self._spec = Composite({key: None for key in policy.out_keys}) + if device is not None: + self._spec = self._spec.to(device) self.checked = False @property diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index eb34b021484..b6783116ba1 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -292,6 +292,7 @@ def __init__( *, loss_function: str = "l2", reduction: str = None, + device: torch.device | None = None, ) -> None: self._in_keys = None self._out_keys = None @@ -343,7 +344,7 @@ def out_keys(self, values): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - tensordict = tensordict.clone(False) + tensordict = tensordict.copy() target_actions = tensordict.get(self.tensor_keys.action_target).detach() with self.actor_network_params.to_module(self.actor_network): @@ -356,8 +357,5 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_function=self.loss_function, ) loss = _reduce(loss, reduction=self.reduction) - out = { - "loss": loss, - } - td_out = TensorDict(out, []) + td_out = TensorDict(loss=loss) return td_out
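
Below is a minimal sketch (not part of the patch) of the update-function wrapping pattern this PR introduces in `dt.py` and `online_dt.py`. The toy linear model, MSE loss, and the `use_compile` / `compile_mode` / `use_cudagraphs` flags are stand-ins for the real `DTLoss`/`OnlineDTLoss` modules and the new `cfg.loss.compile`, `cfg.loss.compile_mode`, and `cfg.loss.cudagraphs` entries added to `dt_config.yaml` and `odt_config.yaml`; everything else follows the structure of the patched scripts.

```python
# Hedged sketch of the compile / cudagraphs wiring added by this patch.
# The model, loss and flags below are illustrative stand-ins, not TorchRL's API.
import warnings

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 2, device=device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)


def update(data: TensorDict) -> TensorDict:
    # Same shape as the patched scripts: zero grads, compute the loss,
    # backprop, step, and return the (detached) loss values.
    optim.zero_grad(set_to_none=True)
    loss = torch.nn.functional.mse_loss(
        model(data["observation"]), data["action_target"]
    )
    loss.backward()
    optim.step()
    return TensorDict({"loss": loss.detach()}, [])


# Stand-ins for cfg.loss.compile, cfg.loss.compile_mode, cfg.loss.cudagraphs.
use_compile, compile_mode, use_cudagraphs = True, None, False
if use_compile:
    if compile_mode in ("", None):
        # When cudagraphs capture the graph, "reduce-overhead" is redundant.
        compile_mode = "default" if use_cudagraphs else "reduce-overhead"
    update = torch.compile(update, mode=compile_mode)
if use_cudagraphs:
    warnings.warn(
        "CudaGraphModule is experimental and may lead to silently wrong results. "
        "Use with caution.",
        category=UserWarning,
    )
    # Requires the data to already live on a CUDA device.
    update = CudaGraphModule(update, warmup=50)

for _ in range(5):
    data = TensorDict(
        {"observation": torch.randn(32, 8), "action_target": torch.randn(32, 2)},
        [32],
    ).to(device)
    loss_vals = update(data)
```

The key design choice mirrored here is that the whole optimization step is folded into a single `update` callable so that `torch.compile` (and optionally `CudaGraphModule`, with a warmup period) can capture the loss computation, backward pass, and optimizer step together, rather than compiling the loss module alone.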