From 6482766b89088aa26974faaca3f18aecd1687b2d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:07:19 -0800 Subject: [PATCH] [Feature] GAIL compatibility with compile ghstack-source-id: 98c7602ec0343d7a83cb19bddeb579484c42e77e Pull Request resolved: https://github.com/pytorch/rl/pull/2573 --- sota-implementations/a2c/utils_atari.py | 2 +- sota-implementations/a2c/utils_mujoco.py | 2 +- .../cql/discrete_cql_config.yaml | 2 +- sota-implementations/cql/online_config.yaml | 2 +- sota-implementations/cql/utils.py | 8 +- sota-implementations/dreamer/dreamer_utils.py | 2 +- sota-implementations/gail/config.yaml | 5 + sota-implementations/gail/gail.py | 185 +++++++++++------- sota-implementations/gail/ppo_utils.py | 23 ++- sota-implementations/impala/utils.py | 2 +- sota-implementations/iql/discrete_iql.yaml | 2 +- sota-implementations/iql/online_config.yaml | 2 +- sota-implementations/iql/utils.py | 8 +- sota-implementations/ppo/utils_atari.py | 2 +- sota-implementations/ppo/utils_mujoco.py | 2 +- sota-implementations/sac/config.yaml | 2 +- sota-implementations/sac/utils.py | 8 +- sota-implementations/td3/config.yaml | 2 +- sota-implementations/td3/utils.py | 8 +- torchrl/_utils.py | 50 +++++ torchrl/data/replay_buffers/replay_buffers.py | 16 +- torchrl/data/replay_buffers/storages.py | 17 +- torchrl/data/replay_buffers/writers.py | 10 +- torchrl/data/tensor_specs.py | 6 +- 24 files changed, 249 insertions(+), 119 deletions(-) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 0397f7dc5f3..6ff62bbe520 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec.to(device), + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index b78f52b7eb4..5ce5ed1902d 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec.to(device), + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 6db31a9aa81..a9fb9bfed0c 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -14,7 +14,7 @@ collector: multi_step: 0 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 annealing_frames: 10000 eps_start: 1.0 diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 5a8be9616a0..5c9e649f17f 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 1000 diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index fa7ed3119b5..306f1cdb7f1 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -124,6 +124,12 @@ def make_collector( cudagraph=False, ): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -131,7 +137,7 @@ def make_collector( frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, compile_policy={"mode": compile_mode} if compile else False, cudagraph_policy=cudagraph, ) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 532fe4e1fe9..7d8b9d6d618 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -546,7 +546,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=proof_environment.single_full_action_spec.to("cpu"), + spec=proof_environment.full_action_spec_unbatched.to("cpu"), ), ), SafeModule( diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index cf6c8053037..089de2c59e4 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -41,6 +41,11 @@ gail: gp_lambda: 10.0 device: null +compile: + compile: False + compile_mode: default + cudagraphs: False + replay_buffer: dataset: halfcheetah-expert-v2 batch_size: 256 diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index b4856fa7d0d..a02845cfe4d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -11,6 +11,8 @@ """ from __future__ import annotations +import warnings + import hydra import numpy as np import torch @@ -18,18 +20,24 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models +from tensordict.nn import CudaGraphModule + +from torchrl._utils import compile_with_warmup from torchrl.collectors import SyncDataCollector -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.objectives import ClipPPOLoss, GAILLoss +from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers from torchrl.objectives.value.advantages import GAE from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger +torch.set_float32_matmul_precision("high") + + @hydra.main(config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() @@ -71,25 +79,20 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.ppo.collector.frames_per_batch, - total_frames=cfg.ppo.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, + actor, critic = make_ppo_models( + cfg.env.env_name, compile=cfg.compile.compile, device=device ) # Create data buffer data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch), + storage=LazyTensorStorage( + cfg.ppo.collector.frames_per_batch, + device=device, + compilable=cfg.compile.compile, + ), sampler=SamplerWithoutReplacement(), batch_size=cfg.ppo.loss.mini_batch_size, + compilable=cfg.compile.compile, ) # Create loss and adv modules @@ -98,6 +101,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.ppo.loss.gae_lambda, value_network=critic, average_gae=False, + device=device, ) loss_module = ClipPPOLoss( @@ -111,8 +115,35 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + actor_optim = torch.optim.Adam( + actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5 + ) + critic_optim = torch.optim.Adam( + critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5 + ) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.ppo.collector.frames_per_batch, + total_frames=cfg.ppo.collector.total_frames, + device=device, + max_frames_per_traj=-1, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) @@ -140,32 +171,9 @@ def main(cfg: "DictConfig"): # noqa: F821 VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) ) test_env.eval() + num_network_updates = torch.zeros((), dtype=torch.int64, device=device) - # Training loop - collected_frames = 0 - num_network_updates = 0 - pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) - - # extract cfg variables - cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs - cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr - cfg_optim_lr = cfg.ppo.optim.lr - cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon - cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon - cfg_logger_test_interval = cfg.logger.test_interval - cfg_logger_num_test_episodes = cfg.logger.num_test_episodes - - for i, data in enumerate(collector): - - log_info = {} - frames_in_batch = data.numel() - collected_frames += frames_in_batch - pbar.update(data.numel()) - - # Update discriminator - # Get expert data - expert_data = replay_buffer.sample() - expert_data = expert_data.to(device) + def update(data, expert_data, num_network_updates=num_network_updates): # Add collector data to expert data expert_data.set( discriminator_loss.tensor_keys.collector_action, @@ -178,9 +186,9 @@ def main(cfg: "DictConfig"): # noqa: F821 d_loss = discriminator_loss(expert_data) # Backward pass - discriminator_optim.zero_grad() d_loss.get("loss").backward() discriminator_optim.step() + discriminator_optim.zero_grad(set_to_none=True) # Compute discriminator reward with torch.no_grad(): @@ -190,40 +198,25 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set discriminator rewards to tensordict data.set(("next", "reward"), d_rewards) - # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( - { - "train/reward": episode_rewards.mean().item(), - "train/episode_length": episode_length.sum().item() - / len(episode_length), - } - ) # Update PPO for _ in range(cfg_loss_ppo_epochs): - # Compute GAE with torch.no_grad(): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer + data_buffer.empty() data_buffer.extend(data_reshape) - for _, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) + for batch in data_buffer: + optim.zero_grad(set_to_none=True) # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 + alpha = torch.ones((), device=device) if cfg_optim_anneal_lr: alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg_optim_lr * alpha - for group in critic_optim.param_groups: + for group in optim.param_groups: group["lr"] = cfg_optim_lr * alpha if cfg_loss_anneal_clip_eps: loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) @@ -235,20 +228,68 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_loss = loss["loss_objective"] + loss["loss_entropy"] # Backward pass - actor_loss.backward() - critic_loss.backward() + (actor_loss + critic_loss).backward() # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + optim.step() + return {"dloss": d_loss, "alpha": alpha} + + if cfg.compile.compile: + update = compile_with_warmup(update, warmup=2, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Training loop + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr + cfg_optim_lr = cfg.ppo.optim.lr + cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + + for i, data in enumerate(collector): + + log_info = {} + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Update discriminator + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) + + metadata = update(data, expert_data) + d_loss = metadata["dloss"] + alpha = metadata["alpha"] + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) log_info.update( { - "train/actor_loss": actor_loss.item(), - "train/critic_loss": critic_loss.item(), - "train/discriminator_loss": d_loss["loss"].item(), + # "train/actor_loss": actor_loss.item(), + # "train/critic_loss": critic_loss.item(), + "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, "train/clip_epsilon": ( alpha * cfg_loss_clip_epsilon diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index e7eb4534c45..7dcc2db6b74 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False) # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, compile, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -52,9 +52,10 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec_unbatched.space.low, - "high": proof_environment.action_spec_unbatched.space.high, + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, + # "safe_tanh": not compile, } # Define policy architecture @@ -63,6 +64,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -75,7 +77,9 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 + proof_environment.action_spec_unbatched.shape[-1], + scale_lb=1e-8, + device=device, ), ) @@ -87,7 +91,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -100,6 +104,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -117,9 +122,11 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): - proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) +def make_ppo_models(env_name, compile, device): + proof_environment = make_env(env_name, device=device) + actor, critic = make_ppo_models_state( + proof_environment, compile=compile, device=device + ) return actor, critic diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 738bb83bf55..e174bc2e71c 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index 9245d4c4832..d28c02cf499 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -15,7 +15,7 @@ collector: total_frames: 20000 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml index 1f7bb361e6c..64ad7466192 100644 --- a/sota-implementations/iql/online_config.yaml +++ b/sota-implementations/iql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index b817d4345c1..261cb912de0 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -120,6 +120,12 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -127,7 +133,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 755c6311729..040259377ad 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index e7eb4534c45..f2e08ffb129 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml index 29586f2e9a7..5cf531a3be2 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -12,7 +12,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - device: cpu + device: env_per_collector: 1 reset_at_each_iter: False diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index 9760793c9cd..6d37f5ec3d8 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -107,13 +107,19 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 7f7854b68b3..5bdf22ea6fa 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -13,7 +13,7 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 reset_at_each_iter: False - device: cpu + device: env_per_collector: 1 num_workers: 1 diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index a9bc8140291..df81a522b3c 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -118,6 +118,12 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -125,7 +131,7 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/torchrl/_utils.py b/torchrl/_utils.py index c81ffcc962b..45f8c433725 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -851,3 +851,53 @@ def set_mode(self, type: Any | None) -> None: cm = self._lock if not is_compiling() else nullcontext() with cm: self._mode = type + + +@wraps(torch.compile) +def compile_with_warmup(*args, warmup: int, **kwargs): + """Compile a model with warm-up. + + This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase, + the original model is used. After the warm-up phase, the model is compiled using + `torch.compile`. + + Args: + *args: Arguments to be passed to `torch.compile`. + warmup (int): Number of calls to the model before compiling it. + **kwargs: Keyword arguments to be passed to `torch.compile`. + + Returns: + A callable that wraps the original model. If no model is provided, returns a + lambda function that takes a model as input and returns the wrapped model. + + Notes: + If no model is provided, this function returns a lambda function that can be + used to wrap a model later. This allows for delayed compilation of the model. + + Example: + >>> model = torch.nn.Linear(5, 3) + >>> compiled_model = compile_with_warmup(model, warmup=10) + >>> # First 10 calls use the original model + >>> # After 10 calls, the model is compiled and used + """ + if len(args): + model = args[0] + args = () + else: + model = kwargs.pop("model", None) + if model is None: + return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs) + else: + count = 0 + compiled_model = model + + @wraps(model) + def count_and_compile(*model_args, **model_kwargs): + nonlocal count + nonlocal compiled_model + count += 1 + if count == warmup: + compiled_model = torch.compile(model, *args, **kwargs) + return compiled_model(*model_args, **model_kwargs) + + return count_and_compile diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 67113095af0..fbb76b5a681 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -20,9 +20,9 @@ import torch try: - from torch.compiler import is_dynamo_compiling + from torch.compiler import is_compiling except ImportError: - from torch._dynamo import is_compiling as is_dynamo_compiling + from torch._dynamo import is_compiling from tensordict import ( is_tensor_collection, @@ -617,9 +617,9 @@ def _add(self, data): return index def _extend(self, data: Sequence) -> torch.Tensor: - is_compiling = is_dynamo_compiling() + is_comp = is_compiling() nc = contextlib.nullcontext() - with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc: + with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data) @@ -672,7 +672,7 @@ def update_priority( @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext(): + with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index data = self._storage.get(index) @@ -1159,7 +1159,9 @@ class TensorDictReplayBuffer(ReplayBuffer): def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None: writer = kwargs.get("writer", None) if writer is None: - kwargs["writer"] = TensorDictRoundRobinWriter() + kwargs["writer"] = TensorDictRoundRobinWriter( + compilable=kwargs.get("compilable") + ) super().__init__(**kwargs) self.priority_key = priority_key @@ -1343,7 +1345,7 @@ def sample( @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock: + with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index data = self._storage.get(index) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index ae0d97b7bab..52d137208ad 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -69,7 +69,7 @@ def __init__( self.max_size = int(max_size) self.checkpointer = checkpointer self._compilable = compilable - self._attached_entities_set = set() + self._attached_entities_list = [] @property def checkpointer(self): @@ -86,17 +86,17 @@ def _is_full(self): return len(self) == self.max_size @property - def _attached_entities(self): + def _attached_entities(self) -> List: # RBs that use a given instance of Storage should add # themselves to this set. - _attached_entities_set = getattr(self, "_attached_entities_set", None) - if _attached_entities_set is None: - self._attached_entities_set = _attached_entities_set = set() - return _attached_entities_set + _attached_entities_list = getattr(self, "_attached_entities_list", None) + if _attached_entities_list is None: + self._attached_entities_list = _attached_entities_list = [] + return _attached_entities_list @torch._dynamo.assume_constant_result def _attached_entities_iter(self): - return list(self._attached_entities) + return self._attached_entities @abc.abstractmethod def set(self, cursor: int, data: Any, *, set_cursor: bool = True): @@ -123,7 +123,8 @@ def attach(self, buffer: Any) -> None: Args: buffer: the object that reads from this storage. """ - self._attached_entities.add(buffer) + if buffer not in self._attached_entities: + self._attached_entities.append(buffer) def __getitem__(self, item): return self.get(item) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 7fb865453d6..e7f4da9c4bb 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -176,7 +176,7 @@ def add(self, data: Any) -> int | torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(_cursor, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -302,7 +302,7 @@ def add(self, data: Any) -> int | torch.Tensor: ) self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -332,7 +332,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -533,7 +533,7 @@ def add(self, data: Any) -> int | torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -567,7 +567,7 @@ def extend(self, data: TensorDictBase) -> None: device = getattr(self._storage, "device", None) out_index = torch.full(data.shape, -1, dtype=torch.long, device=device) index = self._replicate_index(out_index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1898e679717..ad29b63db04 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3319,9 +3319,9 @@ def __init__( self.update_mask(mask) self._provisional_n = None - @torch.compiler.assume_constant_result + @property def _undefined_n(self): - return self.space.n == -1 + return self.space.n < 0 def enumerate(self) -> torch.Tensor: dtype = self.dtype @@ -3388,7 +3388,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self._undefined_n(): + if self._undefined_n: if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). "