From e647a7a48800570963349d63a34571ae7923a3f8 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sat, 14 Dec 2024 09:47:07 -0800
Subject: [PATCH] [Feature] DQN compatibility with compile

ghstack-source-id: b853800ff1661109f635a60e84a7534d74988b09
Pull Request resolved: https://github.com/pytorch/rl/pull/2571
---
 sota-implementations/dqn/config_atari.yaml    |   5 +
 sota-implementations/dqn/config_cartpole.yaml |   5 +
 sota-implementations/dqn/dqn_atari.py         | 113 ++++++++++--------
 sota-implementations/dqn/dqn_cartpole.py      | 105 +++++++++-------
 4 files changed, 139 insertions(+), 89 deletions(-)

diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml
index 50e374cef14..021e7fd6132 100644
--- a/sota-implementations/dqn/config_atari.yaml
+++ b/sota-implementations/dqn/config_atari.yaml
@@ -39,3 +39,8 @@ loss:
   gamma: 0.99
   hard_update_freq: 10_000
   num_updates: 1
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml
index 9a69762d6bd..58be7fb3bb5 100644
--- a/sota-implementations/dqn/config_cartpole.yaml
+++ b/sota-implementations/dqn/config_cartpole.yaml
@@ -38,3 +38,8 @@ loss:
   gamma: 0.99
   hard_update_freq: 50
   num_updates: 1
+
+compile:
+  compile: False
+  compile_mode:
+  cudagraphs: False
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index 4f37502ab76..d8e5047e8b3 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -10,14 +10,14 @@
 from __future__ import annotations
 
 import tempfile
-import time
+import warnings
 
 import hydra
 import torch.nn
 import torch.optim
 import tqdm
-from tensordict.nn import TensorDictSequential
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule, TensorDictSequential
+from torchrl._utils import timeit
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
@@ -60,18 +60,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
         greedy_module,
     ).to(device)
 
-    # Create the collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
-        policy=model_explore,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-        init_random_frames=init_random_frames,
-    )
-
     # Create the replay buffer
     if cfg.buffer.scratch_dir is None:
         tempdir = tempfile.TemporaryDirectory()
@@ -129,25 +117,68 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
     test_env.eval()
 
+    def update(sampled_tensordict):
+        loss_td = loss_module(sampled_tensordict)
+        q_loss = loss_td["loss"]
+        optimizer.zero_grad()
+        q_loss.backward()
+        torch.nn.utils.clip_grad_norm_(
+            list(loss_module.parameters()), max_norm=max_grad
+        )
+        optimizer.step()
+        target_net_updater.step()
+        return q_loss.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create the collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
+        policy=model_explore,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+        init_random_frames=init_random_frames,
+        compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
-    start_time = time.time()
-    sampling_start = time.time()
     num_updates = cfg.loss.num_updates
     max_grad = cfg.optim.max_grad_norm
     num_test_episodes = cfg.logger.num_test_episodes
     q_losses = torch.zeros(num_updates, device=device)
     pbar = tqdm.tqdm(total=total_frames)
 
-    for i, data in enumerate(collector):
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(c_iter)
         log_info = {}
-        sampling_time = time.time() - sampling_start
         pbar.update(data.numel())
         data = data.reshape(-1)
         current_frames = data.numel() * frame_skip
         collected_frames += current_frames
         greedy_module.step(current_frames)
-        replay_buffer.extend(data)
+        with timeit("rb - extend"):
+            replay_buffer.extend(data)
 
         # Get and log training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
@@ -169,24 +200,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
             continue
 
         # optimization steps
-        training_start = time.time()
         for j in range(num_updates):
-
-            sampled_tensordict = replay_buffer.sample()
-            sampled_tensordict = sampled_tensordict.to(device)
-
-            loss_td = loss_module(sampled_tensordict)
-            q_loss = loss_td["loss"]
-            optimizer.zero_grad()
-            q_loss.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=max_grad
-            )
-            optimizer.step()
-            target_net_updater.step()
-            q_losses[j].copy_(q_loss.detach())
-
-        training_time = time.time() - training_start
+            with timeit("rb - sample"):
+                sampled_tensordict = replay_buffer.sample()
+                sampled_tensordict = sampled_tensordict.to(device)
+            with timeit("update"):
+                q_loss = update(sampled_tensordict)
+            q_losses[j].copy_(q_loss)
 
         # Get and log q-values, loss, epsilon, sampling time and training time
         log_info.update(
@@ -195,31 +215,33 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 / frames_per_batch,
                 "train/q_loss": q_losses.mean().item(),
                 "train/epsilon": greedy_module.eps,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
             }
         )
 
         # Get and log evaluation rewards and eval time
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
             cur_test_frame = (i * frames_per_batch) // test_interval
             final = current_frames >= collector.total_frames
             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                 model.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     model, test_env, num_episodes=num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "eval/reward": test_rewards,
-                        "eval/eval_time": eval_time,
                     }
                 )
                 model.train()
 
+        if i % 200 == 0:
+            timeit.print()
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.erase()
+
         # Log all the information
         if logger:
             for key, value in log_info.items():
@@ -227,16 +249,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         # update weights of the inference policy
         collector.update_policy_weights_()
-        sampling_start = time.time()
 
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
     main()
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index b97d8c904fd..e51a538d882 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -2,17 +2,18 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
 from __future__ import annotations
 
-import time
+import warnings
 
 import hydra
 import torch.nn
 import torch.optim
 import tqdm
-from tensordict.nn import TensorDictSequential
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule, TensorDictSequential
+from torchrl._utils import timeit
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.envs import ExplorationType, set_exploration_type
@@ -48,18 +49,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
         greedy_module,
     ).to(device)
 
-    # Create the collector
-    collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, "cpu"),
-        policy=model_explore,
-        frames_per_batch=cfg.collector.frames_per_batch,
-        total_frames=cfg.collector.total_frames,
-        device="cpu",
-        storing_device="cpu",
-        max_frames_per_traj=-1,
-        init_random_frames=cfg.collector.init_random_frames,
-    )
-
     # Create the replay buffer
     replay_buffer = TensorDictReplayBuffer(
         pin_memory=False,
@@ -111,9 +100,47 @@ def main(cfg: "DictConfig"):  # noqa: F821
         ),
     )
 
+    def update(sampled_tensordict):
+        loss_td = loss_module(sampled_tensordict)
+        q_loss = loss_td["loss"]
+        optimizer.zero_grad()
+        q_loss.backward()
+        optimizer.step()
+        target_net_updater.step()
+        return q_loss.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create the collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, "cpu"),
+        policy=model_explore,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        total_frames=cfg.collector.total_frames,
+        device="cpu",
+        storing_device="cpu",
+        max_frames_per_traj=-1,
+        init_random_frames=cfg.collector.init_random_frames,
+        compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+        cudagraph_policy=cfg.compile.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
-    start_time = time.time()
     num_updates = cfg.loss.num_updates
     batch_size = cfg.buffer.batch_size
     test_interval = cfg.logger.test_interval
@@ -121,17 +148,20 @@ def main(cfg: "DictConfig"):  # noqa: F821
     frames_per_batch = cfg.collector.frames_per_batch
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
     init_random_frames = cfg.collector.init_random_frames
-    sampling_start = time.time()
     q_losses = torch.zeros(num_updates, device=device)
 
-    for i, data in enumerate(collector):
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            data = next(c_iter)
         log_info = {}
-        sampling_time = time.time() - sampling_start
         pbar.update(data.numel())
         data = data.reshape(-1)
         current_frames = data.numel()
-        replay_buffer.extend(data)
+
+        with timeit("rb - extend"):
+            replay_buffer.extend(data)
         collected_frames += current_frames
         greedy_module.step(current_frames)
@@ -156,18 +186,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
             continue
 
         # optimization steps
-        training_start = time.time()
         for j in range(num_updates):
-            sampled_tensordict = replay_buffer.sample(batch_size)
-            sampled_tensordict = sampled_tensordict.to(device)
-            loss_td = loss_module(sampled_tensordict)
-            q_loss = loss_td["loss"]
-            optimizer.zero_grad()
-            q_loss.backward()
-            optimizer.step()
-            target_net_updater.step()
-            q_losses[j].copy_(q_loss.detach())
-        training_time = time.time() - training_start
+            with timeit("rb - sample"):
+                sampled_tensordict = replay_buffer.sample(batch_size)
+                sampled_tensordict = sampled_tensordict.to(device)
+            with timeit("update"):
+                q_loss = update(sampled_tensordict)
+            q_losses[j].copy_(q_loss)
 
         # Get and log q-values, loss, epsilon, sampling time and training time
         log_info.update(
@@ -176,29 +201,31 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 / frames_per_batch,
                 "train/q_loss": q_losses.mean().item(),
                 "train/epsilon": greedy_module.eps,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
             }
         )
 
         # Get and log evaluation rewards and eval time
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
             cur_test_frame = (i * frames_per_batch) // test_interval
            final = current_frames >= collector.total_frames
             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
                 model.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(model, test_env, num_test_episodes)
-                eval_time = time.time() - eval_start
                 model.train()
                 log_info.update(
                     {
                         "eval/reward": test_rewards,
-                        "eval/eval_time": eval_time,
                     }
                 )
 
+        if i % 200 == 0:
+            timeit.print()
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.erase()
+
         # Log all the information
         if logger:
             for key, value in log_info.items():
@@ -206,14 +233,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         # update weights of the inference policy
         collector.update_policy_weights_()
-        sampling_start = time.time()
 
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ ==