From b9efe6527f4d68768c653f3865a36205053a14ba Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 12 Nov 2024 10:18:26 +0000
Subject: [PATCH] [Feature] CROSSQ compatibility with compile

ghstack-source-id: 5f9e72fe8bb64a2c55647b9927ce6b35d2634c04
Pull Request resolved: https://github.com/pytorch/rl/pull/2554
---
 sota-implementations/a2c/a2c_atari.py    |   2 +
 sota-implementations/a2c/a2c_mujoco.py   |   2 +
 sota-implementations/cql/cql_offline.py  |   5 +-
 sota-implementations/cql/cql_online.py   |   4 +-
 .../cql/discrete_cql_online.py           |   5 +-
 sota-implementations/cql/utils.py        |  12 +-
 sota-implementations/crossq/config.yaml  |   7 +-
 sota-implementations/crossq/crossq.py    | 195 +++++++++++-------
 sota-implementations/crossq/utils.py     |  25 ++-
 torchrl/objectives/common.py             |  18 +-
 torchrl/objectives/crossq.py             |  14 +-
 torchrl/objectives/value/advantages.py   |  76 +++++--
 12 files changed, 237 insertions(+), 128 deletions(-)

diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 09919eb7dd0..80a9c0efc30 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -7,6 +7,8 @@
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 1160626ce8e..bf4ab273318 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -7,6 +7,8 @@
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index 410ef1dd973..e1dfa8cedba 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -32,6 +32,8 @@
     make_offline_replay_buffer,
 )
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(config_path="", config_name="offline_config", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
@@ -77,7 +79,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_env.start()
 
     # Create loss
-    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model, device=device)
 
     # Create Optimizer
     (
@@ -154,6 +156,7 @@ def update(data, policy_eval_start, iteration):
 
         with timeit("update"):
             # compute loss
+            torch.compiler.cudagraph_mark_step_begin()
             i_device = torch.tensor(i, device=device)
             loss, loss_vals = update(
                 data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index c95b1af708b..a9f1493196f 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -33,6 +33,8 @@
     make_environment,
     make_replay_buffer,
 )
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(version_base="1.1", config_path="", config_name="online_config")
@@ -103,7 +105,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # Create loss
-    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model, device=device)
 
     # Create optimizer
     (
diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py
index 8d08b180175..b5df0a82517 100644
--- a/sota-implementations/cql/discrete_cql_online.py
+++ b/sota-implementations/cql/discrete_cql_online.py
@@ -33,6 +33,8 @@
     make_replay_buffer,
 )
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
 def main(cfg: "DictConfig"):  # noqa: F821
@@ -70,7 +72,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)
 
     # Create loss
-    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
+    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)
 
     compile_mode = None
     if cfg.loss.compile:
@@ -170,6 +172,7 @@ def update(sampled_tensordict):
 
             sampled_tensordict = replay_buffer.sample()
             sampled_tensordict = sampled_tensordict.to(device)
             with timeit("update"):
+                torch.compiler.cudagraph_mark_step_begin()
                 loss_dict = update(sampled_tensordict)
             tds.append(loss_dict)
diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py
index 00f0a81c515..d2a50623d78 100644
--- a/sota-implementations/cql/utils.py
+++ b/sota-implementations/cql/utils.py
@@ -217,8 +217,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
         spec=action_spec,
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": action_spec.space.low,
-            "high": action_spec.space.high,
+            "low": torch.as_tensor(action_spec.space.low, device=device),
+            "high": torch.as_tensor(action_spec.space.high, device=device),
            "tanh_loc": False,
            "safe_tanh": not cfg.loss.compile,
         },
@@ -315,7 +315,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
 # ---------
 
 
-def make_continuous_loss(loss_cfg, model):
+def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
     loss_module = CQLLoss(
         model[0],
         model[1],
@@ -328,19 +328,19 @@ def make_continuous_loss(loss_cfg, model):
         with_lagrange=loss_cfg.with_lagrange,
         lagrange_thresh=loss_cfg.lagrange_thresh,
     )
-    loss_module.make_value_estimator(gamma=loss_cfg.gamma)
+    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
 
     target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
     return loss_module, target_net_updater
 
 
-def make_discrete_loss(loss_cfg, model):
+def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
     loss_module = DiscreteCQLLoss(
         model,
         loss_function=loss_cfg.loss_function,
         delay_value=True,
     )
-    loss_module.make_value_estimator(gamma=loss_cfg.gamma)
+    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
 
     target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
     return loss_module, target_net_updater
diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml
index 1dcbd3db92d..54066a9338a 100644
--- a/sota-implementations/crossq/config.yaml
+++ b/sota-implementations/crossq/config.yaml
@@ -12,7 +12,7 @@ collector:
   init_random_frames: 25000
   frames_per_batch: 1000
   init_env_steps: 1000
-  device: cpu
+  device:
   env_per_collector: 1
   reset_at_each_iter: False
 
@@ -46,7 +46,10 @@ network:
   actor_activation: relu
   default_policy_scale: 1.0
   scale_lb: 0.1
-  device: "cuda:0"
+  device:
+  compile: False
+  compile_mode:
+  cudagraphs: False
 
 # logging
 logger:
diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py
index b07ae880046..6babfe80cd2 100644
--- a/sota-implementations/crossq/crossq.py
+++ b/sota-implementations/crossq/crossq.py
@@ -10,7 +10,6 @@
 
 The helper functions are coded in the utils.py associated with this script.
 """
-import time
 
 import hydra
 
@@ -18,7 +17,10 @@
 import torch
 import torch.cuda
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
 
@@ -32,6 +34,8 @@
     make_replay_buffer,
 )
+import torch
+torch.set_float32_matmul_precision('high')
 
 
 @hydra.main(version_base="1.1", config_path=".", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
@@ -69,10 +73,27 @@ def main(cfg: "DictConfig"):  # noqa: F821
     model, exploration_policy = make_crossQ_agent(cfg, train_env, device)
 
     # Create CrossQ loss
-    loss_module = make_loss_module(cfg, model)
+    loss_module = make_loss_module(cfg, model, device=device)
+
+    compile_mode = None
+    if cfg.network.compile:
+        if cfg.network.compile_mode not in (None, ""):
+            compile_mode = cfg.network.compile_mode
+        elif cfg.network.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
 
     # Create off-policy collector
-    collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device)
+    collector = make_collector(
+        cfg,
+        train_env,
+        exploration_policy.eval(),
+        device=device,
+        compile=cfg.network.compile,
+        compile_mode=compile_mode,
+        cudagraph=cfg.network.cudagraphs,
+    )
 
     # Create replay buffer
     replay_buffer = make_replay_buffer(
@@ -90,8 +111,58 @@ def main(cfg: "DictConfig"):  # noqa: F821
         optimizer_alpha,
     ) = make_crossQ_optimizer(cfg, loss_module)
 
+    def update_qloss(sampled_tensordict):
+        optimizer_critic.zero_grad(set_to_none=True)
+        td_loss = {}
+        q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict)
+        sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"])
+        q_loss = q_loss.mean()
+
+        # Update critic
+        q_loss.backward()
+        optimizer_critic.step()
+        td_loss["loss_qvalue"] = q_loss
+        td_loss["loss_actor"] = float("nan")
+        td_loss["loss_alpha"] = float("nan")
+        return TensorDict(td_loss, device=device).detach()
+
+    def update_all(
+        sampled_tensordict: TensorDict, update_qloss=update_qloss
+    ):  # bind update_qloss
+        # Compute loss
+        td_loss = update_qloss(sampled_tensordict)
+
+        actor_loss, metadata_actor = loss_module.actor_loss(sampled_tensordict)
+        actor_loss = actor_loss.mean()
+        alpha_loss = loss_module.alpha_loss(
+            log_prob=metadata_actor["log_prob"].detach()
+        ).mean()
+
+        # Update actor
+        (actor_loss + alpha_loss).backward()
+        optimizer_actor.step()
+
+        # Update alpha
+        optimizer_alpha.step()
+
+        td_loss["loss_actor"] = actor_loss
+        td_loss["loss_alpha"] = alpha_loss
+
+        return TensorDict(td_loss, device=device).detach()
+
+    if compile_mode:
+        update_all = torch.compile(update_all, mode=compile_mode)
+        update_qloss = torch.compile(update_qloss, mode=compile_mode)
+    if cfg.network.cudagraphs:
+        update_all = CudaGraphModule(update_all, warmup=50)
+        update_qloss = CudaGraphModule(update_qloss, warmup=50)
+
+    def update(sampled_tensordict: TensorDict, update_actor: bool):
+        if update_actor:
+            return update_all(sampled_tensordict)
+        return update_qloss(sampled_tensordict)
+
     # Main loop
-    start_time = time.time()
     collected_frames = 0
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
@@ -106,79 +177,45 @@ def main(cfg: "DictConfig"):  # noqa: F821
     frames_per_batch = cfg.collector.frames_per_batch
     eval_rollout_steps = cfg.env.max_episode_steps
 
-    sampling_start = time.time()
     update_counter = 0
     delayed_updates = cfg.optim.policy_update_delay
 
-    for _, tensordict in enumerate(collector):
-        sampling_time = time.time() - sampling_start
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            torch.compiler.cudagraph_mark_step_begin()
+            tensordict = next(c_iter)
 
         # Update weights of the inference policy
         collector.update_policy_weights_()
-        pbar.update(tensordict.numel())
-
-        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
-        # Add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        pbar.update(current_frames)
+        tensordict = tensordict.reshape(-1)
+
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            replay_buffer.extend(tensordict)
         collected_frames += current_frames
 
         # Optimization steps
-        training_start = time.time()
         if collected_frames >= init_random_frames:
-            (
-                actor_losses,
-                alpha_losses,
-                q_losses,
-            ) = ([], [], [])
+            tds = []
             for _ in range(num_updates):
-                # Update actor every delayed_updates
                 update_counter += 1
                 update_actor = update_counter % delayed_updates == 0
 
                 # Sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
-                if sampled_tensordict.device != device:
-                    sampled_tensordict = sampled_tensordict.to(device)
-                else:
-                    sampled_tensordict = sampled_tensordict.clone()
-
-                # Compute loss
-                q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
-                q_loss = q_loss.mean()
-                # Update critic
-                optimizer_critic.zero_grad()
-                q_loss.backward()
-                optimizer_critic.step()
-                q_losses.append(q_loss.detach().item())
-
-                if update_actor:
-                    actor_loss, metadata_actor = loss_module.actor_loss(
-                        sampled_tensordict
-                    )
-                    actor_loss = actor_loss.mean()
-                    alpha_loss = loss_module.alpha_loss(
-                        log_prob=metadata_actor["log_prob"]
-                    ).mean()
-
-                    # Update actor
-                    optimizer_actor.zero_grad()
-                    actor_loss.backward()
-                    optimizer_actor.step()
-
-                    # Update alpha
-                    optimizer_alpha.zero_grad()
-                    alpha_loss.backward()
-                    optimizer_alpha.step()
-
-                    actor_losses.append(actor_loss.detach().item())
-                    alpha_losses.append(alpha_loss.detach().item())
-
+                with timeit("rb - sample"):
+                    sampled_tensordict = replay_buffer.sample().to(device)
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    td_loss = update(sampled_tensordict, update_actor=update_actor)
+                tds.append(td_loss)
                 # Update priority
                 if prb:
                     replay_buffer.update_priority(sampled_tensordict)
-        training_time = time.time() - training_start
+            tds = TensorDict.stack(tds).nanmean()
 
         episode_end = (
             tensordict["next", "done"]
            if tensordict["next", "done"].any()
@@ -186,47 +223,47 @@ def main(cfg: "DictConfig"):  # noqa: F821
         )
         episode_rewards = tensordict["next", "episode_reward"][episode_end]
 
-        # Logging
         metrics_to_log = {}
-        if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][episode_end]
-            metrics_to_log["train/reward"] = episode_rewards.mean().item()
-            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
-                episode_length
-            )
-        if collected_frames >= init_random_frames:
-            metrics_to_log["train/q_loss"] = np.mean(q_losses).item()
-            metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item()
-            metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item()
-            metrics_to_log["train/sampling_time"] = sampling_time
-            metrics_to_log["train/training_time"] = training_time
 
         # Evaluation
         if abs(collected_frames % eval_iter) < frames_per_batch:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
-                eval_start = time.time()
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
                     model[0],
                     auto_cast_to_device=True,
                    break_when_any_done=True,
                 )
-                eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-                metrics_to_log["eval/time"] = eval_time
+
+        # Logging
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        if i % 20 == 0:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+            metrics_to_log["train/actor_loss"] = tds["loss_actor"]
+            metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
+
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        sampling_start = time.time()
+        if i % 20 == 0:
+            timeit.print()
+            timeit.erase()
 
     collector.shutdown()
     if not eval_env.is_closed:
         eval_env.close()
     if not train_env.is_closed:
         train_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py
index 9883bc50b17..a8638ed7bae 100644
--- a/sota-implementations/crossq/utils.py
+++ b/sota-implementations/crossq/utils.py
@@ -90,7 +90,15 @@ def make_environment(cfg):
 # ---------------------------
 
 
-def make_collector(cfg, train_env, actor_model_explore, device):
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    device,
+    compile=False,
+    compile_mode=None,
+    cudagraph=False,
+):
     """Make collector."""
     collector = SyncDataCollector(
         train_env,
@@ -99,6 +107,8 @@ def make_collector(cfg, train_env, actor_model_explore, device):
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
         device=device,
+        compile_policy={"mode": compile_mode} if compile else False,
+        cudagraph_policy=cudagraph,
     )
     collector.set_seed(cfg.env.seed)
     return collector
@@ -147,9 +157,7 @@ def make_crossQ_agent(cfg, train_env, device):
     """Make CrossQ agent."""
     # Define Actor Network
     in_keys = ["observation"]
-    action_spec = train_env.action_spec
-    if train_env.batch_size:
-        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    action_spec = train_env.single_action_spec
     actor_net_kwargs = {
         "num_cells": cfg.network.actor_hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],
@@ -166,9 +174,10 @@ def make_crossQ_agent(cfg, train_env, device):
 
     dist_class = TanhNormal
     dist_kwargs = {
-        "low": action_spec.space.low,
-        "high": action_spec.space.high,
+        "low": torch.as_tensor(action_spec.space.low, device=device),
+        "high": torch.as_tensor(action_spec.space.high, device=device),
         "tanh_loc": False,
+        "safe_tanh": not cfg.network.compile,
     }
 
     actor_extractor = NormalParamExtractor(
@@ -238,7 +247,7 @@ def make_crossQ_agent(cfg, train_env, device):
 # ---------
 
 
-def make_loss_module(cfg, model):
+def make_loss_module(cfg, model, device: torch.device | None = None):
     """Make loss module and target network updater."""
     # Create CrossQ loss
     loss_module = CrossQLoss(
@@ -248,7 +257,7 @@ def make_loss_module(cfg, model):
         loss_function=cfg.optim.loss_function,
         alpha_init=cfg.optim.alpha_init,
     )
-    loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+    loss_module.make_value_estimator(gamma=cfg.optim.gamma, device=device)
 
     return loss_module
 
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index be05e2fa66b..57310a5fc3d 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -12,6 +12,7 @@
 from dataclasses import dataclass
 from typing import Iterator, List, Optional, Tuple
 
+import torch
 from tensordict import is_tensor_collection, TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
 
@@ -515,7 +516,22 @@ def _default_value_estimator(self):
         from :obj:`torchrl.objectives.utils.DEFAULT_VALUE_FUN_PARAMS`.
 
         """
-        self.make_value_estimator(self.default_value_estimator)
+        self.make_value_estimator(
+            self.default_value_estimator, device=self._default_device
+        )
+
+    @property
+    def _default_device(self) -> torch.device | None:
+        """A util to find the default device.
+
+        Returns ``None`` if parameters are spread across multiple devices.
+        """
+        devices = set()
+        for p in self.parameters():
+            devices.add(p.device)
+        if len(devices) == 1:
+            return list(devices)[0]
+        return None
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
         """Value-function constructor.
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index cfa5a332df9..eb1888fac11 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -340,6 +340,8 @@ def __init__(
         self._action_spec = action_spec
         self._make_vmap()
         self.reduction = reduction
+        # init target entropy
+        _ = self.target_entropy
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
@@ -513,15 +515,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             **metadata_actor,
             **value_metadata,
         }
-        td_out = TensorDict(out, [])
-        # td_out = td_out.named_apply(
-        #     lambda name, value: (
-        #         _reduce(value, reduction=self.reduction)
-        #         if name.startswith("loss_")
-        #         else value
-        #     ),
-        #     batch_size=[],
-        # )
+        td_out = TensorDict(out)
         return td_out
 
     @property
@@ -543,6 +537,7 @@ def actor_loss(
         Returns:
             a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action.
         """
+        tensordict = tensordict.copy()
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
@@ -584,6 +579,7 @@ def qvalue_loss(
         Returns:
             a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling.
         """
+        tensordict = tensordict.copy()
         # # compute next action
         with torch.no_grad():
             with set_exploration_type(
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index c90f16911dc..04004e32458 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -197,6 +197,8 @@ def forward(
                 to be passed to the functional value network module.
             target_params (TensorDictBase, optional): A nested TensorDict containing the target
                 params to be passed to the functional value network module.
+            device (torch.device, optional): the device where the buffers will be instantiated.
+                Defaults to ``torch.get_default_device()``.
 
         Returns:
             An updated TensorDict with an advantage and a value_error keys as defined in the constructor.
@@ -213,8 +215,14 @@ def __init__(
         advantage_key: NestedKey = None,
         value_target_key: NestedKey = None,
         value_key: NestedKey = None,
+        device: torch.device | None = None,
     ):
         super().__init__()
+        if device is None:
+            device = torch.get_default_device()
+        # this is saved for tracking only and should not be used to cast anything else than buffers during
+        # init.
+        self._device = device
         self._tensor_keys = None
         self.differentiable = differentiable
         self.skip_existing = skip_existing
@@ -518,7 +526,8 @@ class TD0Estimator(ValueEstimatorBase):
             of the advantage entry. Defaults to ``"value_target"``.
         value_key (str or tuple of str, optional): [Deprecated] the value key to
             read from the input tensordict. Defaults to ``"state_value"``.
-        device (torch.device, optional): device of the module.
+        device (torch.device, optional): the device where the buffers will be instantiated.
+            Defaults to ``torch.get_default_device()``.
 
     """
 
@@ -544,8 +553,9 @@ def __init__(
             value_target_key=value_target_key,
             value_key=value_key,
             skip_existing=skip_existing,
+            device=device,
         )
-        self.register_buffer("gamma", torch.tensor(gamma, device=device))
+        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
         self.average_rewards = average_rewards
 
     @_self_set_skip_existing
@@ -664,7 +674,9 @@ def value_estimate(
     ):
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -723,7 +735,8 @@ class TD1Estimator(ValueEstimatorBase):
             estimation, for instance) and (2) when the parameters used at time
             ``t`` and ``t+1`` are identical (which is not the case when target
             parameters are to be used). Defaults to ``False``.
-        device (torch.device, optional): device of the module.
+        device (torch.device, optional): the device where the buffers will be instantiated.
+            Defaults to ``torch.get_default_device()``.
         time_dim (int, optional): the dimension corresponding to the time
             in the input tensordict. If not provided, defaults to the dimension
             markes with the ``"time"`` name if any, and to the last dimension
@@ -757,8 +770,9 @@ def __init__(
             value_key=value_key,
             shifted=shifted,
             skip_existing=skip_existing,
+            device=device,
         )
-        self.register_buffer("gamma", torch.tensor(gamma, device=device))
+        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
         self.average_rewards = average_rewards
         self.time_dim = time_dim
 
@@ -879,7 +893,9 @@ def value_estimate(
     ):
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -943,7 +959,8 @@ class TDLambdaEstimator(ValueEstimatorBase):
             estimation, for instance) and (2) when the parameters used at time
             ``t`` and ``t+1`` are identical (which is not the case when target
             parameters are to be used). Defaults to ``False``.
-        device (torch.device, optional): device of the module.
+        device (torch.device, optional): the device where the buffers will be instantiated.
+            Defaults to ``torch.get_default_device()``.
         time_dim (int, optional): the dimension corresponding to the time
             in the input tensordict. If not provided, defaults to the dimension
             markes with the ``"time"`` name if any, and to the last dimension
@@ -979,9 +996,10 @@ def __init__(
             value_key=value_key,
             skip_existing=skip_existing,
             shifted=shifted,
+            device=device,
         )
-        self.register_buffer("gamma", torch.tensor(gamma, device=device))
-        self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
+        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
+        self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
         self.average_rewards = average_rewards
         self.vectorized = vectorized
         self.time_dim = time_dim
@@ -1103,7 +1121,9 @@ def value_estimate(
     ):
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1185,7 +1205,8 @@ class GAE(ValueEstimatorBase):
             estimation, for instance) and (2) when the parameters used at time
             ``t`` and ``t+1`` are identical (which is not the case when target
             parameters are to be used). Defaults to ``False``.
-        device (torch.device, optional): device of the module.
+        device (torch.device, optional): the device where the buffers will be instantiated.
+            Defaults to ``torch.get_default_device()``.
         time_dim (int, optional): the dimension corresponding to the time
             in the input tensordict. If not provided, defaults to the dimension
             marked with the ``"time"`` name if any, and to the last dimension
@@ -1233,9 +1254,10 @@ def __init__(
             value_target_key=value_target_key,
             value_key=value_key,
             skip_existing=skip_existing,
+            device=device,
         )
-        self.register_buffer("gamma", torch.tensor(gamma, device=device))
-        self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
+        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
+        self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
         self.average_gae = average_gae
         self.vectorized = vectorized
         self.time_dim = time_dim
@@ -1336,7 +1358,12 @@ def forward(
             )
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma, lmbda = self.gamma.to(device), self.lmbda.to(device)
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+        if self.lmbda.device != device:
+            self.lmbda = self.lmbda.to(device)
+        lmbda = self.lmbda
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1417,7 +1444,12 @@ def value_estimate(
             )
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma, lmbda = self.gamma.to(device), self.lmbda.to(device)
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+        if self.lmbda.device != device:
+            self.lmbda = self.lmbda.to(device)
+        lmbda = self.lmbda
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1506,7 +1538,8 @@ class VTrace(ValueEstimatorBase):
             estimation, for instance) and (2) when the parameters used at time
             ``t`` and ``t+1`` are identical (which is not the case when target
             parameters are to be used). Defaults to ``False``.
-        device (torch.device, optional): device of the module.
+        device (torch.device, optional): the device where the buffers will be instantiated.
+            Defaults to ``torch.get_default_device()``.
         time_dim (int, optional): the dimension corresponding to the time
             in the input tensordict. If not provided, defaults to the dimension
             markes with the ``"time"`` name if any, and to the last dimension
@@ -1551,13 +1584,14 @@ def __init__(
             value_target_key=value_target_key,
             value_key=value_key,
             skip_existing=skip_existing,
+            device=device,
         )
         if not isinstance(gamma, torch.Tensor):
-            gamma = torch.tensor(gamma, device=device)
+            gamma = torch.tensor(gamma, device=self._device)
         if not isinstance(rho_thresh, torch.Tensor):
-            rho_thresh = torch.tensor(rho_thresh, device=device)
+            rho_thresh = torch.tensor(rho_thresh, device=self._device)
         if not isinstance(c_thresh, torch.Tensor):
-            c_thresh = torch.tensor(c_thresh, device=device)
+            c_thresh = torch.tensor(c_thresh, device=self._device)
 
         self.register_buffer("gamma", gamma)
         self.register_buffer("rho_thresh", rho_thresh)
@@ -1688,7 +1722,9 @@ def forward(
             )
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
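For reference, the compile/cudagraph wrapping used throughout the scripts above boils down to the minimal, self-contained sketch below. A toy linear model and random batch stand in for the actual agent, loss module and replay buffer; the names net, optim, batch and the warmup value are placeholders and not taken from the patch, while torch.set_float32_matmul_precision, torch.compile, tensordict.nn.CudaGraphModule and torch.compiler.cudagraph_mark_step_begin mirror what the patch uses.

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

torch.set_float32_matmul_precision("high")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(8, 1, device=device)
optim = torch.optim.Adam(net.parameters(), lr=3e-4)


def update(batch: TensorDict) -> TensorDict:
    # One optimization step: zero grads, compute a toy loss, backprop, step.
    optim.zero_grad(set_to_none=True)
    loss = (net(batch["obs"]) - batch["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    return TensorDict({"loss": loss}, device=device).detach()


# Same mode selection as crossq.py: when CudaGraphModule captures the graph,
# torch.compile stays in "default" mode; otherwise "reduce-overhead" is used.
use_cudagraphs = device.type == "cuda"
compile_mode = "default" if use_cudagraphs else "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if use_cudagraphs:
    # Replays the update as a CUDA graph after `warmup` eager iterations.
    update = CudaGraphModule(update, warmup=5)

for _ in range(10):
    batch = TensorDict(
        {
            "obs": torch.randn(32, 8, device=device),
            "target": torch.randn(32, 1, device=device),
        },
        batch_size=[32],
        device=device,
    )
    # Mark the iteration boundary so cudagraph-backed execution knows tensors
    # produced by the previous call may now be overwritten.
    torch.compiler.cudagraph_mark_step_begin()
    print(update(batch)["loss"].item())

In the scripts themselves the same idea is applied to update_qloss/update_all (CrossQ) and to the CQL update functions, and the collector receives the equivalent treatment through the new compile_policy and cudagraph_policy arguments passed to SyncDataCollector in make_collector.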