diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index b3f8e242a9e..76066e35c4e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -308,11 +308,12 @@ Utils :toctree: generated/ :template: rl_template_noinherit.rst + HardUpdate + SoftUpdate + ValueEstimators + default_value_kwargs distance_loss + group_optimizers hold_out_net hold_out_params next_state_value - SoftUpdate - HardUpdate - ValueEstimators - default_value_kwargs diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 73155d9fa1a..410ef1dd973 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -15,8 +15,11 @@ import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -69,6 +72,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create agent model = make_cql_model(cfg, train_env, eval_env, device) del train_env + if hasattr(eval_env, "start"): + # To set the number of threads to the definitive value + eval_env.start() # Create loss loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) @@ -81,81 +87,104 @@ def main(cfg: "DictConfig"): # noqa: F821 alpha_prime_optim, ) = make_continuous_cql_optimizer(cfg, loss_module) - pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + # Group optimizers + optimizer = group_optimizers( + policy_optim, critic_optim, alpha_optim, alpha_prime_optim + ) - gradient_steps = cfg.optim.gradient_steps - policy_eval_start = cfg.optim.policy_eval_start - evaluation_interval = cfg.logger.eval_iter - eval_steps = cfg.logger.eval_steps - - # Training loop - start_time = time.time() - for i in range(gradient_steps): - pbar.update(1) - # sample data - data = replay_buffer.sample() - # compute loss - loss_vals = loss_module(data.clone().to(device)) + def update(data, policy_eval_start, iteration): + loss_vals = loss_module(data.to(device)) # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks - if i >= policy_eval_start: - actor_loss = loss_vals["loss_actor"] - else: - actor_loss = loss_vals["loss_actor_bc"] + actor_loss = torch.where( + iteration >= policy_eval_start, + loss_vals["loss_actor"], + loss_vals["loss_actor_bc"], + ) q_loss = loss_vals["loss_qvalue"] cql_loss = loss_vals["loss_cql"] q_loss = q_loss + cql_loss + loss_vals["q_loss"] = q_loss # update model alpha_loss = loss_vals["loss_alpha"] alpha_prime_loss = loss_vals["loss_alpha_prime"] + if alpha_prime_loss is None: + alpha_prime_loss = 0 - alpha_optim.zero_grad() - alpha_loss.backward() - alpha_optim.step() + loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss - policy_optim.zero_grad() - actor_loss.backward() - policy_optim.step() + loss.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) - if alpha_prime_optim is not None: - alpha_prime_optim.zero_grad() - alpha_prime_loss.backward(retain_graph=True) - alpha_prime_optim.step() + # update qnet_target params + target_net_updater.step() - critic_optim.zero_grad() - # TODO: we have the option to compute losses independently retain is not needed? 
- q_loss.backward(retain_graph=False) - critic_optim.step() + return loss.detach(), loss_vals.detach() - loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss + compile_mode = None + if cfg.loss.compile: + if cfg.loss.compile_mode not in (None, ""): + compile_mode = cfg.loss.compile_mode + elif cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, warmup=50) + + pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + + gradient_steps = cfg.optim.gradient_steps + policy_eval_start = cfg.optim.policy_eval_start + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + + # Training loop + start_time = time.time() + policy_eval_start = torch.tensor(policy_eval_start, device=device) + for i in range(gradient_steps): + pbar.update(1) + # sample data + with timeit("sample"): + data = replay_buffer.sample() + + with timeit("update"): + # compute loss + i_device = torch.tensor(i, device=device) + loss, loss_vals = update( + data.to(device), policy_eval_start=policy_eval_start, iteration=i_device + ) # log metrics to_log = { - "loss": loss.item(), - "loss_actor_bc": loss_vals["loss_actor_bc"].item(), - "loss_actor": loss_vals["loss_actor"].item(), - "loss_qvalue": q_loss.item(), - "loss_cql": cql_loss.item(), - "loss_alpha": alpha_loss.item(), - "loss_alpha_prime": alpha_prime_loss.item(), + "loss": loss.cpu(), + **loss_vals.cpu(), } - # update qnet_target params - target_net_updater.step() - # evaluation - if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_td = eval_env.rollout( - max_steps=eval_steps, policy=model[0], auto_cast_to_device=True - ) - eval_env.apply(dump_video) - eval_reward = eval_td["next", "reward"].sum(1).mean().item() - to_log["evaluation_reward"] = eval_reward - - log_metrics(logger, to_log, i) + with timeit("log/eval"): + if i % evaluation_interval == 0: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_td = eval_env.rollout( + max_steps=eval_steps, policy=model[0], auto_cast_to_device=True + ) + eval_env.apply(dump_video) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward + + with timeit("log"): + if i % 200 == 0: + to_log.update(timeit.todict(prefix="time")) + log_metrics(logger, to_log, i) + if i % 200 == 0: + timeit.print() + timeit.erase() pbar.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index cf629ed0733..c95b1af708b 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -11,15 +11,16 @@ The helper functions are coded in the utils.py associated with this script. 
""" -import time - import hydra import numpy as np import torch import tqdm from tensordict import TensorDict -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -82,8 +83,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # create agent model = make_cql_model(cfg, train_env, eval_env, device) + compile_mode = None + if cfg.loss.compile: + if cfg.loss.compile_mode not in (None, ""): + compile_mode = cfg.loss.compile_mode + elif cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0]) + collector = make_collector( + cfg, + train_env, + actor_model_explore=model[0], + compile=cfg.loss.compile, + compile_mode=compile_mode, + cudagraph=cfg.loss.cudagraphs, + ) # Create loss loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) @@ -95,8 +112,37 @@ def main(cfg: "DictConfig"): # noqa: F821 alpha_optim, alpha_prime_optim, ) = make_continuous_cql_optimizer(cfg, loss_module) + optimizer = group_optimizers( + policy_optim, critic_optim, alpha_optim, alpha_prime_optim + ) + + def update(sampled_tensordict): + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + cql_loss = loss_td["loss_cql"] + q_loss = q_loss + cql_loss + alpha_loss = loss_td["loss_alpha"] + alpha_prime_loss = loss_td["loss_alpha_prime"] + + total_loss = alpha_loss + actor_loss + alpha_prime_loss + q_loss + total_loss.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + # update qnet_target params + target_net_updater.step() + + return loss_td.detach() + + if compile_mode: + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, warmup=50) + # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -111,69 +157,38 @@ def main(cfg: "DictConfig"): # noqa: F821 evaluation_interval = cfg.logger.log_interval eval_rollout_steps = cfg.logger.eval_steps - sampling_start = time.time() - for i, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) pbar.update(tensordict.numel()) # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) - collected_frames += current_frames + with timeit("rb - extend"): + tensordict = tensordict.view(-1) + current_frames = tensordict.numel() + # add to replay buffer + replay_buffer.extend(tensordict) + collected_frames += current_frames - # optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - log_loss_td = TensorDict({}, [num_updates]) + log_loss_td = TensorDict(batch_size=[num_updates], device=device) for j in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, 
non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - loss_td = loss_module(sampled_tensordict) - - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] - cql_loss = loss_td["loss_cql"] - q_loss = q_loss + cql_loss - alpha_loss = loss_td["loss_alpha"] - alpha_prime_loss = loss_td["loss_alpha_prime"] - - alpha_optim.zero_grad() - alpha_loss.backward() - alpha_optim.step() - - policy_optim.zero_grad() - actor_loss.backward() - policy_optim.step() - - if alpha_prime_optim is not None: - alpha_prime_optim.zero_grad() - alpha_prime_loss.backward(retain_graph=True) - alpha_prime_optim.step() - - critic_optim.zero_grad() - q_loss.backward(retain_graph=False) - critic_optim.step() + with timeit("rb - sample"): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().to(device) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss_td = update(sampled_tensordict) log_loss_td[j] = loss_td.detach() - - # update qnet_target params - target_net_updater.step() - # update priority if prb: - replay_buffer.update_priority(sampled_tensordict) + with timeit("rb - update priority"): + replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] @@ -195,36 +210,32 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_alpha_prime" ).mean() metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + if i % 10 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) # Evaluation - - prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval - cur_test_frame = (i * frames_per_batch) // evaluation_interval - final = current_frames >= collector.total_frames - if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() - eval_rollout = eval_env.rollout( - eval_rollout_steps, - model[0], - auto_cast_to_device=True, - break_when_any_done=True, - ) - eval_time = time.time() - eval_start - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - eval_env.apply(dump_video) - metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + with timeit("eval"): + prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval + cur_test_frame = (i * frames_per_batch) // evaluation_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_env.apply(dump_video) + metrics_to_log["eval/reward"] = eval_reward log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() - - collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + if i % 10 == 0: + timeit.print() + timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 644b8ec624e..e05c73208a9 
100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -10,7 +10,7 @@ env: # Collector collector: frames_per_batch: 200 - total_frames: 20000 + total_frames: 1_000_000 multi_step: 0 init_random_frames: 1000 env_per_collector: 1 @@ -57,3 +57,6 @@ loss: loss_function: l2 gamma: 0.99 tau: 0.005 + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index d0d6693eb97..8d08b180175 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -10,14 +10,15 @@ The helper functions are coded in the utils.py associated with this script. """ -import time import hydra import numpy as np import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -71,8 +72,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create loss loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + compile_mode = None + if cfg.loss.compile: + if cfg.loss.compile_mode not in (None, ""): + compile_mode = cfg.loss.compile_mode + elif cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create off-policy collector - collector = make_collector(cfg, train_env, explore_policy) + collector = make_collector( + cfg, + train_env, + explore_policy, + compile=cfg.loss.compile, + compile_mode=compile_mode, + cudagraph=cfg.loss.cudagraphs, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -86,6 +103,28 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers optimizer = make_discrete_cql_optimizer(cfg, loss_module) + def update(sampled_tensordict): + # Compute loss + optimizer.zero_grad(set_to_none=True) + loss_dict = loss_module(sampled_tensordict) + + q_loss = loss_dict["loss_qvalue"] + cql_loss = loss_dict["loss_cql"] + loss = q_loss + cql_loss + + # Update model + loss.backward() + optimizer.step() + + # Update target params + target_net_updater.step() + return loss_dict.detach() + + if compile_mode: + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, warmup=50) + # Main loop collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -101,9 +140,11 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - start_time = sampling_start = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) # Update exploration policy explore_policy[1].step(tensordict.numel()) @@ -111,53 +152,31 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) + current_frames = tensordict.numel() + pbar.update(current_frames) tensordict = tensordict.reshape(-1) - current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + # Add to replay buffer + replay_buffer.extend(tensordict) collected_frames += current_frames # Optimization steps - training_start = time.time() 
if collected_frames >= init_random_frames: - ( - q_losses, - cql_losses, - ) = ([], []) + tds = [] for _ in range(num_updates): - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - loss_dict = loss_module(sampled_tensordict) - - q_loss = loss_dict["loss_qvalue"] - cql_loss = loss_dict["loss_cql"] - loss = q_loss + cql_loss + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + loss_dict = update(sampled_tensordict) + tds.append(loss_dict) - # Update model - optimizer.zero_grad() - loss.backward() - optimizer.step() - q_losses.append(q_loss.item()) - cql_losses.append(cql_loss.item()) - - # Update target params - target_net_updater.step() # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -165,8 +184,23 @@ def main(cfg: "DictConfig"): # noqa: F821 ) episode_rewards = tensordict["next", "episode_reward"][episode_end] - # Logging metrics_to_log = {} + # Evaluation + with timeit("eval"): + if collected_frames % eval_iter < frames_per_batch: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model, + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + + # Logging if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][episode_end] metrics_to_log["train/reward"] = episode_rewards.mean().item() @@ -176,33 +210,20 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/epsilon"] = explore_policy[1].eps if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) - metrics_to_log["train/cql_loss"] = np.mean(cql_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + tds = torch.stack(tds, dim=0).mean() + metrics_to_log["train/q_loss"] = tds["loss_qvalue"] + metrics_to_log["train/cql_loss"] = tds["loss_cql"] + if i % 100 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + + if i % 100 == 0: + timeit.print() + timeit.erase() - # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() - eval_rollout = eval_env.rollout( - eval_rollout_steps, - model, - auto_cast_to_device=True, - break_when_any_done=True, - ) - eval_time = time.time() - eval_start - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index bf213d4e3c5..e78fcc0a03e 100644 --- 
a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -54,3 +54,6 @@ loss: num_random: 10 with_lagrange: True lagrange_thresh: 5.0 # tau + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 00db1d6bb62..5b3742975f8 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -11,7 +11,7 @@ env: # Collector collector: frames_per_batch: 1000 - total_frames: 20000 + total_frames: 1_000_000 multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 @@ -66,3 +66,6 @@ loss: num_random: 10 with_lagrange: True lagrange_thresh: 10.0 + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index c1d6fb52024..00f0a81c515 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -113,7 +113,14 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector( + cfg, + train_env, + actor_model_explore, + compile=False, + compile_mode=None, + cudagraph=False, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -123,6 +130,8 @@ def make_collector(cfg, train_env, actor_model_explore): max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=cfg.collector.device, + compile_policy={"mode": compile_mode} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector @@ -191,7 +200,7 @@ def make_offline_replay_buffer(rb_cfg): def make_cql_model(cfg, train_env, eval_env, device="cpu"): model_cfg = cfg.model - action_spec = train_env.action_spec + action_spec = train_env.single_action_spec actor_net, q_net = make_cql_modules_state(model_cfg, eval_env) in_keys = ["observation"] @@ -208,11 +217,10 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "low": action_spec.space.low[len(train_env.batch_size) :], - "high": action_spec.space.high[ - len(train_env.batch_size) : - ], # remove batch-size + "low": action_spec.space.low, + "high": action_spec.space.high, "tanh_loc": False, + "safe_tanh": not cfg.loss.compile, }, default_interaction_type=ExplorationType.RANDOM, ) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9e59e0f69d6..17bd28c8390 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1356,12 +1356,15 @@ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator + num_threads = max( + 1, torch.get_num_threads() - self.num_workers + ) # 1 more thread for this proc + if self.num_threads is None: - self.num_threads = max( - 1, torch.get_num_threads() - self.num_workers - ) # 1 more thread for this proc + self.num_threads = num_threads - torch.set_num_threads(self.num_threads) + if self.num_threads != torch.get_num_threads(): + torch.set_num_threads(self.num_threads) if self._mp_start_method is not None: ctx = mp.get_context(self._mp_start_method) diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 546d93cb228..8c332c4efed 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -9,6 +9,11 @@ from torch import autograd, 
distributions as d from torch.distributions import Independent, Transform, TransformedDistribution +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]: if isinstance(elt, torch.Tensor): @@ -40,10 +45,12 @@ class FasterTransformedDistribution(TransformedDistribution): __doc__ = __doc__ + TransformedDistribution.__doc__ def __init__(self, base_distribution, transforms, validate_args=None): + if is_dynamo_compiling(): + return super().__init__( + base_distribution, transforms, validate_args=validate_args + ) if isinstance(transforms, Transform): - self.transforms = [ - transforms, - ] + self.transforms = [transforms] elif isinstance(transforms, list): raise ValueError("Make a ComposeTransform first.") else: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 7337d1c94dd..f04e0c78382 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -150,7 +150,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: action_key = self.action_key out = action_tensordict.get(action_key) - eps = self.eps.item() + eps = self.eps cond = torch.rand(action_tensordict.shape, device=out.device) < eps cond = expand_as_right(cond, out) spec = self.spec diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 1ea9ebb5998..01f993e629a 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -23,6 +23,7 @@ from .utils import ( default_value_kwargs, distance_loss, + group_optimizers, HardUpdate, hold_out_net, hold_out_params, diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index a6cb21dd2a4..55575ba2b6e 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -375,6 +375,7 @@ def __init__( ) self._make_vmap() self.reduction = reduction + _ = self.target_entropy def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index aeaa092494e..972fd200e0e 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -597,3 +597,21 @@ def _get_default_device(net): return p.device else: return torch.get_default_device() + + +def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer: + """Groups multiple optimizers into a single one. + + All optimizers are expected to have the same type. + """ + cls = None + params = [] + for optimizer in optimizers: + if optimizer is None: + continue + if cls is None: + cls = type(optimizer) + if cls is not type(optimizer): + raise ValueError("Cannot group optimizers of different type.") + params.extend(optimizer.param_groups) + return cls(params)
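
The `group_optimizers` utility added in `torchrl/objectives/utils.py` is what lets the rewritten loops above replace four separate `zero_grad()`/`backward()`/`step()` sequences with a single optimizer call. A minimal usage sketch of that pattern follows; the modules and learning rate are placeholders, not values from the CQL configs.

```python
# Minimal sketch of the group_optimizers pattern used in the CQL scripts above.
# The modules and learning rate are hypothetical placeholders.
import torch
from torch import nn
from torchrl.objectives import group_optimizers

actor = nn.Linear(4, 4)
critic = nn.Linear(4, 4)

policy_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)

# None entries (e.g. a disabled alpha_prime optimizer) are skipped by group_optimizers.
optimizer = group_optimizers(policy_optim, critic_optim, None)

x = torch.randn(8, 4)
loss = actor(x).sum() + critic(x).sum()
loss.backward()
optimizer.step()                       # one step drives all grouped param groups
optimizer.zero_grad(set_to_none=True)
```

Note that the grouped optimizer is a fresh instance built from the original param groups, which is why all inputs must share the same optimizer class.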
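The compile-mode selection and the `torch.compile`/`CudaGraphModule` wrapping of `update` are repeated almost verbatim in `cql_offline.py`, `cql_online.py`, and `discrete_cql_online.py`. A possible follow-up, not part of this diff, would be to factor the pattern into a shared helper; the sketch below uses a hypothetical `compile_update` name and assumes the `compile`, `compile_mode`, and `cudagraphs` fields added to the `loss` section of the configs above.

```python
# Hypothetical helper mirroring the compile/cudagraph wrapping added to the
# three CQL training scripts in this diff; the name compile_update is not
# part of the code base.
import torch
from tensordict.nn import CudaGraphModule


def compile_update(update, loss_cfg, warmup: int = 50):
    """Optionally wrap an update function with torch.compile and CudaGraphModule."""
    compile_mode = None
    if loss_cfg.compile:
        if loss_cfg.compile_mode not in (None, ""):
            compile_mode = loss_cfg.compile_mode
        elif loss_cfg.cudagraphs:
            # the diff selects "default" when CudaGraphModule will handle graph capture
            compile_mode = "default"
        else:
            compile_mode = "reduce-overhead"
    if compile_mode:
        update = torch.compile(update, mode=compile_mode)
    if loss_cfg.cudagraphs:
        update = CudaGraphModule(update, warmup=warmup)
    return update, compile_mode
```

The returned `compile_mode` could still be forwarded to `make_collector(..., compile_mode=compile_mode)` as the online scripts do.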
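The manual `time.time()` bookkeeping removed above is replaced by the `timeit` helper from `torchrl._utils`. For reference, a self-contained sketch of that instrumentation pattern, with stand-in functions in place of the replay buffer and the update step and an illustrative reporting interval:

```python
# Self-contained sketch of the timeit instrumentation adopted in the loops above.
import time

from torchrl._utils import timeit


def fake_sample():
    time.sleep(0.001)  # stand-in for replay_buffer.sample()


def fake_update():
    time.sleep(0.002)  # stand-in for the (possibly compiled) update step


for i in range(400):
    with timeit("sample"):
        fake_sample()
    with timeit("update"):
        fake_update()
    if i % 200 == 0:
        metrics = timeit.todict(prefix="time")  # timing entries, keys carry the given prefix
        timeit.print()  # human-readable summary
        timeit.erase()  # reset the accumulated timings
```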