From db0e41603f245cfe585b5b372de73a7ed694420c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 15 Dec 2024 10:22:54 -0800
Subject: [PATCH] [Feature] PPO compatibility with compile

ghstack-source-id: 473ad225a61b6c455b7f6cc792f9d9cca72cabb6
Pull Request resolved: https://github.com/pytorch/rl/pull/2652
---
 sota-implementations/dqn/dqn_atari.py       |   2 +
 sota-implementations/dqn/dqn_cartpole.py    |   2 +
 sota-implementations/iql/iql_online.py      |   4 +-
 sota-implementations/ppo/config_atari.yaml  |   6 +
 sota-implementations/ppo/config_mujoco.yaml |   6 +
 sota-implementations/ppo/ppo_atari.py       | 174 +++++++++++--------
 sota-implementations/ppo/ppo_mujoco.py      | 182 ++++++++++++--------
 sota-implementations/ppo/utils_atari.py     |  27 +--
 sota-implementations/ppo/utils_mujoco.py    |  16 +-
 torchrl/_utils.py                           |   6 +-
 10 files changed, 257 insertions(+), 168 deletions(-)

diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index 255b6b2ee65..d43ac25c822 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -28,6 +28,8 @@
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_atari import eval_model, make_dqn_model, make_env
 
+torch.set_float32_matmul_precision("high")
+
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index 89a1e04d586..873cf278d4b 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -23,6 +23,8 @@
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_cartpole import eval_model, make_dqn_model, make_env
 
+torch.set_float32_matmul_precision("high")
+
 
 @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index 499c2164b52..6d904115e0d 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -103,7 +103,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         compile_mode = "reduce-overhead"
 
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode)
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )
 
     # Create loss
     loss_module, target_net_updater = make_loss(cfg.loss, model)
diff --git a/sota-implementations/ppo/config_atari.yaml b/sota-implementations/ppo/config_atari.yaml
index 31e6f13a58c..f7a340e3512 100644
--- a/sota-implementations/ppo/config_atari.yaml
+++ b/sota-implementations/ppo/config_atari.yaml
@@ -25,6 +25,7 @@ optim:
   weight_decay: 0.0
   max_grad_norm: 0.5
   anneal_lr: True
+  device:
 
 # loss
 loss:
@@ -37,3 +38,8 @@ loss:
   critic_coef: 1.0
   entropy_coef: 0.01
   loss_critic_type: l2
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
diff --git a/sota-implementations/ppo/config_mujoco.yaml b/sota-implementations/ppo/config_mujoco.yaml
index 2dd3c6cc229..822aea89616 100644
--- a/sota-implementations/ppo/config_mujoco.yaml
+++ b/sota-implementations/ppo/config_mujoco.yaml
@@ -22,6 +22,7 @@ optim:
   lr: 3e-4
   weight_decay: 0.0
   anneal_lr: True
+  device:
 
 # loss
 loss:
@@ -34,3 +35,8 @@ loss:
   critic_coef: 0.25
   entropy_coef: 0.0
   loss_critic_type: l2
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
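Note: the new optim.device and compile.* keys added to the two config files above are
read by ppo_atari.py and ppo_mujoco.py below. As a reading aid, here is a condensed
Python sketch of how the scripts resolve them; it paraphrases the hunks that follow
rather than quoting them, and the helper name is hypothetical.

    import torch

    def resolve_device_and_compile_mode(cfg):
        # Hypothetical helper, for illustration only.
        # An empty optim.device means "cuda:0 if available, else cpu".
        device = cfg.optim.device
        if device in ("", None):
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)

        # compile.compile toggles torch.compile on the update step and the GAE module.
        # An empty compile.compile_mode falls back to "default" when cudagraphs are
        # requested, and to "reduce-overhead" otherwise.
        compile_mode = None
        if cfg.compile.compile:
            compile_mode = cfg.compile.compile_mode
            if compile_mode in ("", None):
                compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"
        return device, compile_mode
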
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 7878a0286e3..f5f0ae41945 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -9,30 +9,42 @@
 """
 from __future__ import annotations
 
+import warnings
+
 import hydra
-from torchrl._utils import logger as torchrl_logger
-from torchrl.record import VideoRecorder
+
+from torchrl._utils import compile_with_warmup
 
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
-    import time
-
     import torch.optim
     import tqdm
 
     from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models
 
-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    torch.set_float32_matmul_precision("high")
+
+    device = cfg.optim.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    device = torch.device(device)
 
     # Correct for frame_skip
     frame_skip = 4
@@ -41,9 +53,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
     mini_batch_size = cfg.loss.mini_batch_size // frame_skip
     test_interval = cfg.logger.test_interval // frame_skip
 
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create models (check utils_atari.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
-    actor, critic = actor.to(device), critic.to(device)
+    actor, critic = make_ppo_models(cfg.env.env_name, device=device)
 
     # Create collector
     collector = SyncDataCollector(
@@ -54,14 +74,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
         device="cpu",
         storing_device="cpu",
         max_frames_per_traj=-1,
+        compile_policy={"mode": compile_mode} if compile_mode else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
     )
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, compilable=cfg.compile.compile),
         sampler=sampler,
         batch_size=mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
@@ -121,15 +144,52 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Main loop
     collected_frames = 0
-    num_network_updates = 0
-    start_time = time.time()
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (
         (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches
     )
 
-    sampling_start = time.time()
+    def update(batch, num_network_updates):
+        optim.zero_grad(set_to_none=True)
+
+        # Linearly decrease the learning rate and clip epsilon
+        alpha = torch.ones((), device=device)
+        if cfg_optim_anneal_lr:
+            alpha = 1 - (num_network_updates / total_network_updates)
+            for group in optim.param_groups:
+                group["lr"] = cfg_optim_lr * alpha
+        if cfg_loss_anneal_clip_eps:
+            loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+        num_network_updates += 1
+        # Get a data batch
+        batch = batch.to(device, non_blocking=True)
+
+        # Forward pass PPO loss
+        loss = loss_module(batch)
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+        # Backward pass
+        loss_sum.backward()
+        torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        return loss.detach().set("alpha", alpha), num_network_updates.clone()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
 
     # extract cfg variables
     cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
@@ -142,13 +202,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
     losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
 
-    for i, data in enumerate(collector):
+    collector_iter = iter(collector)
+
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(collector_iter)
 
         log_info = {}
-        sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
-        pbar.update(data.numel())
+        pbar.update(frames_in_batch)
 
         # Get training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
@@ -162,96 +225,65 @@ def main(cfg: "DictConfig"):  # noqa: F821
             }
         )
 
-        training_start = time.time()
-        for j in range(cfg_loss_ppo_epochs):
-
-            # Compute GAE
-            with torch.no_grad():
-                data = adv_module(data.to(device, non_blocking=True))
-            data_reshape = data.reshape(-1)
-            # Update the data buffer
-            data_buffer.extend(data_reshape)
-
-            for k, batch in enumerate(data_buffer):
-
-                # Linearly decrease the learning rate and clip epsilon
-                alpha = 1.0
-                if cfg_optim_anneal_lr:
-                    alpha = 1 - (num_network_updates / total_network_updates)
-                    for group in optim.param_groups:
-                        group["lr"] = cfg_optim_lr * alpha
-                if cfg_loss_anneal_clip_eps:
-                    loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
-                num_network_updates += 1
-                # Get a data batch
-                batch = batch.to(device, non_blocking=True)
-
-                # Forward pass PPO loss
-                loss = loss_module(batch)
-                losses[j, k] = loss.select(
-                    "loss_critic", "loss_entropy", "loss_objective"
-                ).detach()
-                loss_sum = (
-                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-                )
-                # Backward pass
-                loss_sum.backward()
-                torch.nn.utils.clip_grad_norm_(
-                    list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
-                )
-
-                # Update the networks
-                optim.step()
-                optim.zero_grad()
+        with timeit("training"):
+            for j in range(cfg_loss_ppo_epochs):
+
+                # Compute GAE
+                with torch.no_grad(), timeit("adv"):
+                    data = adv_module(data.to(device))
+                with timeit("rb - extend"):
+                    # Update the data buffer
+                    data_reshape = data.reshape(-1)
+                    data_buffer.extend(data_reshape)
+
+                for k, batch in enumerate(data_buffer):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss, num_network_updates = update(
+                        batch, num_network_updates=num_network_updates
+                    )
+                    losses[j, k] = loss.select(
+                        "loss_critic", "loss_entropy", "loss_objective"
+                    )
 
         # Get training losses and times
-        training_time = time.time() - training_start
         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses_mean.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg_optim_lr,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
-                "train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
+                "train/lr": loss["alpha"] * cfg_optim_lr,
+                "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon,
             }
         )
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                 i * frames_in_batch * frame_skip
             ) // test_interval:
                 actor.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "eval/reward": test_rewards.mean(),
-                        "eval/time": eval_time,
                     }
                 )
                 actor.train()
-
         if logger:
+            log_info.update(timeit.todict(prefix="time"))
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
-        sampling_start = time.time()
 
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
-
 
 if __name__ == "__main__":
     main()
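Note: the Atari and MuJoCo trainers now share the same capture-friendly inner loop: the
update counter is a zero-dim int64 tensor rather than a Python int, update() returns
loss.detach() together with num_network_updates.clone(), and every minibatch step is
preceded by torch.compiler.cudagraph_mark_step_begin(). The sketch below paraphrases
that calling pattern; the rationale in the comments is one reading and is not stated in
the patch itself.

    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)

    for k, batch in enumerate(data_buffer):
        # With CudaGraphModule or mode="reduce-overhead", mark the step so outputs of
        # the previous graph replay are not overwritten while they are still in use.
        torch.compiler.cudagraph_mark_step_begin()
        loss, num_network_updates = update(batch, num_network_updates=num_network_updates)
        # update() hands back detached / cloned tensors, so the caller never keeps a
        # reference into a buffer owned by the captured graph.
        losses[j, k] = loss.select("loss_critic", "loss_entropy", "loss_objective")
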
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index c1d6fe52585..2252a68b3de 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -9,30 +9,43 @@
 """
 from __future__ import annotations
 
+import warnings
+
 import hydra
-from torchrl._utils import logger as torchrl_logger
-from torchrl.record import VideoRecorder
+
+from torchrl._utils import compile_with_warmup
 
 
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
-    import time
-
     import torch.optim
     import tqdm
 
     from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
-    from torchrl.objectives import ClipPPOLoss
+    from torchrl.objectives import ClipPPOLoss, group_optimizers
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_mujoco import eval_model, make_env, make_ppo_models
 
-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    torch.set_float32_matmul_precision("high")
+
+    device = cfg.optim.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    device = torch.device(device)
+
     num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
     total_network_updates = (
         (cfg.collector.total_frames // cfg.collector.frames_per_batch)
@@ -40,9 +53,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
         * num_mini_batches
     )
 
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
     # Create models (check utils_mujoco.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
-    actor, critic = actor.to(device), critic.to(device)
+    actor, critic = make_ppo_models(cfg.env.env_name, device=device)
 
     # Create collector
     collector = SyncDataCollector(
@@ -53,14 +74,19 @@ def main(cfg: "DictConfig"):  # noqa: F821
         device=device,
         storing_device=device,
         max_frames_per_traj=-1,
+        compile_policy={"mode": compile_mode} if compile_mode else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
     )
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(cfg.collector.frames_per_batch),
+        storage=LazyTensorStorage(
+            cfg.collector.frames_per_batch, compilable=cfg.compile.compile
+        ),
         sampler=sampler,
         batch_size=cfg.loss.mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
@@ -84,6 +110,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create optimizers
     actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5)
+    optim = group_optimizers(actor_optim, critic_optim)
+    del actor_optim, critic_optim
 
     # Create logger
     logger = None
@@ -111,14 +139,48 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
     test_env.eval()
 
+    def update(batch, num_network_updates):
+        optim.zero_grad(set_to_none=True)
+        # Linearly decrease the learning rate and clip epsilon
+        alpha = torch.ones((), device=device)
+        if cfg_optim_anneal_lr:
+            alpha = 1 - (num_network_updates / total_network_updates)
+            for group in optim.param_groups:
+                group["lr"] = cfg_optim_lr * alpha
+        if cfg_loss_anneal_clip_eps:
+            loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+        num_network_updates += 1
+
+        # Forward pass PPO loss
+        loss = loss_module(batch)
+        critic_loss = loss["loss_critic"]
+        actor_loss = loss["loss_objective"] + loss["loss_entropy"]
+        total_loss = critic_loss + actor_loss
+
+        # Backward pass
+        total_loss.backward()
+
+        # Update the networks
+        optim.step()
+        return loss.detach().set("alpha", alpha), num_network_updates.clone()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
+
     # Main loop
     collected_frames = 0
-    num_network_updates = 0
-    start_time = time.time()
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
-    sampling_start = time.time()
-
     # extract cfg variables
     cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
     cfg_optim_anneal_lr = cfg.optim.anneal_lr
@@ -129,13 +191,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
     losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
 
-    for i, data in enumerate(collector):
+    collector_iter = iter(collector)
+
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(collector_iter)
 
         log_info = {}
-        sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
-        pbar.update(data.numel())
+        pbar.update(frames_in_batch)
 
         # Get training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
@@ -149,100 +214,67 @@ def main(cfg: "DictConfig"):  # noqa: F821
             }
         )
 
-        training_start = time.time()
-        for j in range(cfg_loss_ppo_epochs):
-
-            # Compute GAE
-            with torch.no_grad():
-                data = adv_module(data)
-            data_reshape = data.reshape(-1)
-
-            # Update the data buffer
-            data_buffer.extend(data_reshape)
-
-            for k, batch in enumerate(data_buffer):
-
-                # Get a data batch
-                batch = batch.to(device)
-
-                # Linearly decrease the learning rate and clip epsilon
-                alpha = 1.0
-                if cfg_optim_anneal_lr:
-                    alpha = 1 - (num_network_updates / total_network_updates)
-                    for group in actor_optim.param_groups:
-                        group["lr"] = cfg_optim_lr * alpha
-                    for group in critic_optim.param_groups:
-                        group["lr"] = cfg_optim_lr * alpha
-                if cfg_loss_anneal_clip_eps:
-                    loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
-                num_network_updates += 1
-
-                # Forward pass PPO loss
-                loss = loss_module(batch)
-                losses[j, k] = loss.select(
-                    "loss_critic", "loss_entropy", "loss_objective"
-                ).detach()
-                critic_loss = loss["loss_critic"]
-                actor_loss = loss["loss_objective"] + loss["loss_entropy"]
-
-                # Backward pass
-                actor_loss.backward()
-                critic_loss.backward()
-
-                # Update the networks
-                actor_optim.step()
-                critic_optim.step()
-                actor_optim.zero_grad()
-                critic_optim.zero_grad()
+        with timeit("training"):
+            for j in range(cfg_loss_ppo_epochs):
+
+                # Compute GAE
+                with torch.no_grad(), timeit("adv"):
+                    data = adv_module(data.to(device))
+                with timeit("rb - extend"):
+                    # Update the data buffer
+                    data_reshape = data.reshape(-1)
+                    data_buffer.extend(data_reshape)
+
+                for k, batch in enumerate(data_buffer):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss, num_network_updates = update(
+                        batch, num_network_updates=num_network_updates
+                    )
+                    losses[j, k] = loss.select(
+                        "loss_critic", "loss_entropy", "loss_objective"
+                    )
 
         # Get training losses and times
-        training_time = time.time() - training_start
         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses_mean.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg_optim_lr,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
-                "train/clip_epsilon": alpha * cfg_loss_clip_epsilon
+                "train/lr": loss["alpha"] * cfg_optim_lr,
+                "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon
                 if cfg_loss_anneal_clip_eps
                 else cfg_loss_clip_epsilon,
             }
         )
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
            if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
                 i * frames_in_batch
             ) // cfg_logger_test_interval:
                 actor.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "eval/reward": test_rewards.mean(),
-                        "eval/time": eval_time,
                     }
                 )
                 actor.train()
 
         if logger:
+            log_info.update(timeit.todict(prefix="time"))
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
-        sampling_start = time.time()
 
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py
index 040259377ad..ab39c102106 100644
--- a/sota-implementations/ppo/utils_atari.py
+++ b/sota-implementations/ppo/utils_atari.py
@@ -31,7 +31,6 @@
     ActorValueOperator,
     ConvNet,
     MLP,
-    OneHotCategorical,
     ProbabilisticActor,
     TanhNormal,
     ValueOperator,
@@ -51,6 +50,7 @@ def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False
         from_pixels=True,
         pixels_only=False,
         device="cpu",
+        categorical_action_encoding=True,
     )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
@@ -86,7 +86,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
 # --------------------------------------------------------------------
 
 
-def make_ppo_modules_pixels(proof_environment):
+def make_ppo_modules_pixels(proof_environment, device):
 
     # Define input shape
     input_shape = proof_environment.observation_spec["pixels"].shape
@@ -94,14 +94,14 @@ def make_ppo_modules_pixels(proof_environment):
     # Define distribution class and kwargs
     if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
         num_outputs = proof_environment.action_spec_unbatched.space.n
-        distribution_class = OneHotCategorical
+        distribution_class = torch.distributions.Categorical
         distribution_kwargs = {}
     else:  # is ContinuousBox
         num_outputs = proof_environment.action_spec_unbatched.shape
         distribution_class = TanhNormal
         distribution_kwargs = {
-            "low": proof_environment.action_spec_unbatched.space.low,
-            "high": proof_environment.action_spec_unbatched.space.high,
+            "low": proof_environment.action_spec_unbatched.space.low.to(device),
+            "high": proof_environment.action_spec_unbatched.space.high.to(device),
         }
 
     # Define input keys
@@ -113,6 +113,7 @@ def make_ppo_modules_pixels(proof_environment):
         num_cells=[32, 64, 64],
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
+        device=device,
     )
     common_cnn_output = common_cnn(torch.ones(input_shape))
     common_mlp = MLP(
@@ -121,6 +122,7 @@ def make_ppo_modules_pixels(proof_environment):
         activate_last_layer=True,
         out_features=512,
         num_cells=[],
+        device=device,
     )
     common_mlp_output = common_mlp(common_cnn_output)
 
@@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment):
         out_features=num_outputs,
         activation_class=torch.nn.ReLU,
         num_cells=[],
+        device=device,
     )
     policy_module = TensorDictModule(
         module=policy_net,
@@ -148,7 +151,7 @@
     policy_module = ProbabilisticActor(
         policy_module,
         in_keys=["logits"],
-        spec=proof_environment.full_action_spec_unbatched,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
@@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment):
         in_features=common_mlp_output.shape[-1],
         out_features=1,
         num_cells=[],
+        device=device,
    )
     value_module = ValueOperator(
         value_net,
@@ -170,11 +174,12 @@ def make_ppo_modules_pixels(proof_environment):
     return common_module, policy_module, value_module
 
 
-def make_ppo_models(env_name):
+def make_ppo_models(env_name, device):
 
-    proof_environment = make_parallel_env(env_name, 1, device="cpu")
+    proof_environment = make_parallel_env(env_name, 1, device=device)
     common_module, policy_module, value_module = make_ppo_modules_pixels(
-        proof_environment
+        proof_environment,
+        device=device,
     )
 
     # Wrap modules in a single ActorCritic operator
@@ -185,8 +190,8 @@ def make_ppo_models(env_name):
     )
 
     with torch.no_grad():
-        td = proof_environment.rollout(max_steps=100, break_when_any_done=False)
-        td = actor_critic(td)
+        td = proof_environment.fake_tensordict().expand(10)
+        actor_critic(td)
         del td
 
     actor = actor_critic.get_policy_operator()
diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py
index f2e08ffb129..584945013dc 100644
--- a/sota-implementations/ppo/utils_mujoco.py
+++ b/sota-implementations/ppo/utils_mujoco.py
@@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
 # --------------------------------------------------------------------
 
 
-def make_ppo_models_state(proof_environment):
+def make_ppo_models_state(proof_environment, device):
 
     # Define input shape
     input_shape = proof_environment.observation_spec["observation"].shape
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec_unbatched.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec_unbatched.space.low,
-        "high": proof_environment.action_spec_unbatched.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
         "tanh_loc": False,
     }
 
@@ -63,6 +63,7 @@ def make_ppo_models_state(proof_environment):
         activation_class=torch.nn.Tanh,
         out_features=num_outputs,  # predict only loc
         num_cells=[64, 64],
+        device=device,
     )
 
     # Initialize policy weights
@@ -87,7 +88,7 @@ def make_ppo_models_state(proof_environment):
             out_keys=["loc", "scale"],
         ),
         in_keys=["loc", "scale"],
-        spec=proof_environment.full_action_spec_unbatched,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
@@ -100,6 +101,7 @@ def make_ppo_models_state(proof_environment):
         activation_class=torch.nn.Tanh,
         out_features=1,
         num_cells=[64, 64],
+        device=device,
     )
 
     # Initialize value weights
@@ -117,9 +119,9 @@ def make_ppo_models_state(proof_environment):
     return policy_module, value_module
 
 
-def make_ppo_models(env_name):
-    proof_environment = make_env(env_name, device="cpu")
-    actor, critic = make_ppo_models_state(proof_environment)
+def make_ppo_models(env_name, device):
+    proof_environment = make_env(env_name, device=device)
+    actor, critic = make_ppo_models_state(proof_environment, device=device)
     return actor, critic
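Note: the torchrl/_utils.py hunks below give compile_with_warmup a default of warmup=1
and start its internal call counter at -1. A minimal usage sketch, matching how the PPO
scripts above call it (the mode string is just an example value; per the docstring, the
first warmup calls run the eager model before compilation kicks in):

    from torchrl._utils import compile_with_warmup

    update = compile_with_warmup(update, mode="reduce-overhead", warmup=1)
    adv_module = compile_with_warmup(adv_module, mode="reduce-overhead", warmup=1)
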
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 45f8c433725..cc1621d8723 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -854,7 +854,7 @@ def set_mode(self, type: Any | None) -> None:
 
 
 @wraps(torch.compile)
-def compile_with_warmup(*args, warmup: int, **kwargs):
+def compile_with_warmup(*args, warmup: int = 1, **kwargs):
     """Compile a model with warm-up.
 
     This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase,
@@ -863,7 +863,7 @@ def compile_with_warmup(*args, warmup: int, **kwargs):
 
     Args:
         *args: Arguments to be passed to `torch.compile`.
-        warmup (int): Number of calls to the model before compiling it.
+        warmup (int): Number of calls to the model before compiling it. Defaults to 1.
         **kwargs: Keyword arguments to be passed to `torch.compile`.
 
     Returns:
@@ -888,7 +888,7 @@ def compile_with_warmup(*args, warmup: int, **kwargs):
     if model is None:
         return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs)
     else:
-        count = 0
+        count = -1
         compiled_model = model
 
         @wraps(model)