From 0fcd4b61a44fe829486834969d230882c0353fbf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 09:47:08 -0800 Subject: [PATCH] [Feature] GAIL compatibility with compile ghstack-source-id: d4928af8e8ba023f2808b752cc7319157ffdfbd3 Pull Request resolved: https://github.com/pytorch/rl/pull/2573 --- .../cql/discrete_cql_config.yaml | 2 +- sota-implementations/cql/online_config.yaml | 2 +- sota-implementations/cql/utils.py | 8 +- sota-implementations/gail/config.yaml | 5 + sota-implementations/gail/gail.py | 172 +++++++++++------- sota-implementations/gail/ppo_utils.py | 18 +- sota-implementations/iql/discrete_iql.yaml | 2 +- sota-implementations/iql/online_config.yaml | 2 +- sota-implementations/iql/utils.py | 8 +- sota-implementations/sac/config.yaml | 2 +- sota-implementations/sac/utils.py | 8 +- sota-implementations/td3/config.yaml | 2 +- sota-implementations/td3/utils.py | 8 +- 13 files changed, 152 insertions(+), 87 deletions(-) diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 6db31a9aa81..a9fb9bfed0c 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -14,7 +14,7 @@ collector: multi_step: 0 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 annealing_frames: 10000 eps_start: 1.0 diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 5a8be9616a0..5c9e649f17f 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 1000 diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index fa7ed3119b5..306f1cdb7f1 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -124,6 +124,12 @@ def make_collector( cudagraph=False, ): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -131,7 +137,7 @@ def make_collector( frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, compile_policy={"mode": compile_mode} if compile else False, cudagraph_policy=cudagraph, ) diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index cf6c8053037..2e057b08220 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -41,6 +41,11 @@ gail: gp_lambda: 10.0 device: null +compile: + compile: False + compile_mode: + cudagraphs: False + replay_buffer: dataset: halfcheetah-expert-v2 batch_size: 256 diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index b4856fa7d0d..39eaf3a929d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -11,6 +11,8 @@ """ from __future__ import annotations +import warnings + import hydra import numpy as np import torch @@ -18,13 +20,15 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.objectives import ClipPPOLoss, GAILLoss +from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers from torchrl.objectives.value.advantages import GAE from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger @@ -71,18 +75,8 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.ppo.collector.frames_per_batch, - total_frames=cfg.ppo.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, + actor, critic = make_ppo_models( + cfg.env.env_name, compile=cfg.compile.compile, device=device ) # Create data buffer @@ -111,8 +105,36 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + actor_optim = torch.optim.Adam( + actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5 + ) + critic_optim = torch.optim.Adam( + critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5 + ) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.ppo.collector.frames_per_batch, + total_frames=cfg.ppo.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) @@ -140,32 +162,9 @@ def main(cfg: "DictConfig"): # noqa: F821 VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) ) test_env.eval() + num_network_updates = torch.zeros((), dtype=torch.int64, device=device) - # Training loop - collected_frames = 0 - num_network_updates = 0 - pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) - - # extract cfg variables - cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs - cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr - cfg_optim_lr = cfg.ppo.optim.lr - cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon - cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon - cfg_logger_test_interval = cfg.logger.test_interval - cfg_logger_num_test_episodes = cfg.logger.num_test_episodes - - for i, data in enumerate(collector): - - log_info = {} - frames_in_batch = data.numel() - collected_frames += frames_in_batch - pbar.update(data.numel()) - - # Update discriminator - # Get expert data - expert_data = replay_buffer.sample() - expert_data = expert_data.to(device) + def update(data, expert_data, num_network_updates=num_network_updates): # Add collector data to expert data expert_data.set( discriminator_loss.tensor_keys.collector_action, @@ -178,9 +177,9 @@ def main(cfg: "DictConfig"): # noqa: F821 d_loss = discriminator_loss(expert_data) # Backward pass - discriminator_optim.zero_grad() d_loss.get("loss").backward() discriminator_optim.step() + discriminator_optim.zero_grad(set_to_none=True) # Compute discriminator reward with torch.no_grad(): @@ -190,40 +189,25 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set discriminator rewards to tensordict data.set(("next", "reward"), d_rewards) - # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( - { - "train/reward": episode_rewards.mean().item(), - "train/episode_length": episode_length.sum().item() - / len(episode_length), - } - ) # Update PPO for _ in range(cfg_loss_ppo_epochs): - # Compute GAE with torch.no_grad(): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer + data_buffer.empty() data_buffer.extend(data_reshape) - for _, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) + for batch in data_buffer: + optim.zero_grad(set_to_none=True) # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 + alpha = torch.ones((), device=device) if cfg_optim_anneal_lr: alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg_optim_lr * alpha - for group in critic_optim.param_groups: + for group in optim.param_groups: group["lr"] = cfg_optim_lr * alpha if cfg_loss_anneal_clip_eps: loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) @@ -235,20 +219,68 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_loss = loss["loss_objective"] + loss["loss_entropy"] # Backward pass - actor_loss.backward() - critic_loss.backward() + (actor_loss + critic_loss).backward() # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + optim.step() + return TensorDict(dloss=d_loss, alpha=alpha).detach() + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Training loop + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr + cfg_optim_lr = cfg.ppo.optim.lr + cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + + for i, data in enumerate(collector): + + log_info = {} + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Update discriminator + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) + + metadata = update(data, expert_data) + d_loss = metadata["d_loss"] + alpha = metadata["alpha"] + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) log_info.update( { - "train/actor_loss": actor_loss.item(), - "train/critic_loss": critic_loss.item(), - "train/discriminator_loss": d_loss["loss"].item(), + # "train/actor_loss": actor_loss.item(), + # "train/critic_loss": critic_loss.item(), + "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, "train/clip_epsilon": ( alpha * cfg_loss_clip_epsilon diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index e7eb4534c45..6ba12acdf9c 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False) # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, compile, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -52,9 +52,10 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec_unbatched.space.low, - "high": proof_environment.action_spec_unbatched.space.high, + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, + "safe_tanh": not compile, } # Define policy architecture @@ -63,6 +64,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -87,7 +89,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -117,9 +119,11 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): - proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) +def make_ppo_models(env_name, compile, device): + proof_environment = make_env(env_name, device=device) + actor, critic = make_ppo_models_state( + proof_environment, compile=compile, device=device + ) return actor, critic diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index 9245d4c4832..d28c02cf499 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -15,7 +15,7 @@ collector: total_frames: 20000 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml index 1f7bb361e6c..64ad7466192 100644 --- a/sota-implementations/iql/online_config.yaml +++ b/sota-implementations/iql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index b817d4345c1..261cb912de0 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -120,6 +120,12 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -127,7 +133,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml index 29586f2e9a7..5cf531a3be2 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -12,7 +12,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - device: cpu + device: env_per_collector: 1 reset_at_each_iter: False diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index 9760793c9cd..6d37f5ec3d8 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -107,13 +107,19 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 7f7854b68b3..5bdf22ea6fa 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -13,7 +13,7 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 reset_at_each_iter: False - device: cpu + device: env_per_collector: 1 num_workers: 1 diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index a9bc8140291..df81a522b3c 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -118,6 +118,12 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -125,7 +131,7 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector