diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index c7f70308fd4..47e43125ea4 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -182,7 +182,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): lr = cfg.optim.lr c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collecting"): data = next(c_iter) @@ -261,10 +264,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): "test/reward": test_rewards.mean(), } ) - if i % 200 == 0: - log_info.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + log_info.update(timeit.todict(prefix="time")) if logger: for key, value in log_info.items(): diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index cf88e7db01a..07ad5197954 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -179,7 +179,10 @@ def update(batch): pbar = tqdm.tqdm(total=cfg.collector.total_frames) c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collecting"): data = next(c_iter) @@ -257,10 +260,7 @@ def update(batch): ) actor.train() - if i % 200 == 0: - log_info.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + log_info.update(timeit.todict(prefix="time")) if logger: for key, value in log_info.items(): diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index e74997eb37f..f10f8c73814 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -11,7 +11,6 @@ """ from __future__ import annotations -import time import warnings import hydra @@ -21,7 +20,7 @@ import tqdm from tensordict.nn import CudaGraphModule -from torchrl._utils import logger as torchrl_logger, timeit +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger @@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration): eval_steps = cfg.logger.eval_steps # Training loop - start_time = time.time() policy_eval_start = torch.tensor(policy_eval_start, device=device) for i in range(gradient_steps): + timeit.printevery(1000, gradient_steps, erase=True) pbar.update(1) # sample data with timeit("sample"): @@ -195,12 +194,8 @@ def update(data, policy_eval_start, iteration): if i % 200 == 0: to_log.update(timeit.todict(prefix="time")) log_metrics(logger, to_log, i) - if i % 200 == 0: - timeit.print() - timeit.erase() pbar.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if not eval_env.is_closed: eval_env.close() diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 61a19894ce0..4e7713344f9 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -170,7 +170,9 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.logger.eval_steps c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): tensordict = next(c_iter) pbar.update(tensordict.numel()) @@ -245,9 +247,6 @@ def update(sampled_tensordict): metrics_to_log["eval/reward"] = eval_reward log_metrics(logger, metrics_to_log, collected_frames) - if i % 10 == 0: - timeit.print() - timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index c5a06b4b156..9aab7fa7799 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -227,10 +227,6 @@ def update(sampled_tensordict): if i % 100 == 0: metrics_to_log.update(timeit.todict(prefix="time")) - if i % 100 == 0: - timeit.print() - timeit.erase() - if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index a0068b6662e..4c29e95cb36 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -192,7 +192,9 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): update_counter = 0 delayed_updates = cfg.optim.policy_update_delay c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): torch.compiler.cudagraph_mark_step_begin() tensordict = next(c_iter) @@ -267,9 +269,6 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - if i % 20 == 0: - timeit.print() - timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index c3e3c9eb835..9d06dc2ff75 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -156,7 +156,9 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.env.max_episode_steps c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): tensordict = next(c_iter) # Update exploration policy @@ -226,10 +228,7 @@ def update(sampled_tensordict): eval_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - if i % 20 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 6ac058b9843..57ba327b935 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -128,6 +128,7 @@ def update(data: TensorDict) -> TensorDict: # Pretraining pbar = tqdm.tqdm(range(pretrain_gradient_steps)) for i in pbar: + timeit.printevery(1000, pretrain_gradient_steps, erase=True) # Sample data with timeit("rb - sample"): data = offline_buffer.sample().to(model_device) @@ -151,10 +152,7 @@ def update(data: TensorDict) -> TensorDict: to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) - if i % 200 == 0: - to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, to_log, i) diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 9f3ec5f8134..7c6c9968774 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -8,7 +8,6 @@ """ from __future__ import annotations -import time import warnings import hydra @@ -130,8 +129,8 @@ def update(data): torchrl_logger.info(" ***Pretraining*** ") # Pretraining - start_time = time.time() for i in range(pretrain_gradient_steps): + timeit.printevery(1000, pretrain_gradient_steps, erase=True) pbar.update(1) with timeit("sample"): # Sample data @@ -170,10 +169,7 @@ def update(data): eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) - if i % 200 == 0: - to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, to_log, i) @@ -181,7 +177,6 @@ def update(data): pbar.close() if not test_env.is_closed: test_env.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index c88206e1330..a5dad120a60 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -155,7 +155,9 @@ def update(sampled_tensordict): frames_per_batch = cfg.collector.frames_per_batch c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): collected_data = next(c_iter) @@ -229,10 +231,7 @@ def update(sampled_tensordict): eval_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - if i % 50 == 0: - metrics_to_log.update(timeit.todict(prefix="time")) - timeit.print() - timeit.erase() + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index d43ac25c822..b4236c3e89f 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -173,7 +173,9 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=total_frames) c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): data = next(c_iter) log_info = {} @@ -241,10 +243,7 @@ def update(sampled_tensordict): ) model.train() - if i % 200 == 0: - timeit.print() - log_info.update(timeit.todict(prefix="time")) - timeit.erase() + log_info.update(timeit.todict(prefix="time")) # Log all the information if logger: diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 873cf278d4b..57236337ced 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -156,7 +156,9 @@ def update(sampled_tensordict): q_losses = torch.zeros(num_updates, device=device) c_iter = iter(collector) - for i in range(len(collector)): + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) with timeit("collecting"): data = next(c_iter) @@ -226,10 +228,7 @@ def update(sampled_tensordict): } ) - if i % 200 == 0: - timeit.print() - log_info.update(timeit.todict(prefix="time")) - timeit.erase() + log_info.update(timeit.todict(prefix="time")) # Log all the information if logger: diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a02845cfe4d..45d3acbb85f 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -22,7 +22,7 @@ from ppo_utils import eval_model, make_env, make_ppo_models from tensordict.nn import CudaGraphModule -from torchrl._utils import compile_with_warmup +from torchrl._utils import compile_with_warmup, timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -256,19 +256,28 @@ def update(data, expert_data, num_network_updates=num_network_updates): cfg_logger_test_interval = cfg.logger.test_interval cfg_logger_num_test_episodes = cfg.logger.num_test_episodes - for i, data in enumerate(collector): + total_iter = len(collector) + collector_iter = iter(collector) + for i in range(total_iter): + + timeit.printevery(1000, total_iter, erase=True) + + with timeit("collection"): + data = next(collector_iter) log_info = {} frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(data.numel()) - # Update discriminator - # Get expert data - expert_data = replay_buffer.sample() - expert_data = expert_data.to(device) + with timeit("rb - sample expert"): + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) - metadata = update(data, expert_data) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + metadata = update(data, expert_data) d_loss = metadata["dloss"] alpha = metadata["alpha"] @@ -287,8 +296,6 @@ def update(data, expert_data, num_network_updates=num_network_updates): log_info.update( { - # "train/actor_loss": actor_loss.item(), - # "train/critic_loss": critic_loss.item(), "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, "train/clip_epsilon": ( @@ -300,7 +307,9 @@ def update(data, expert_data, num_network_updates=num_network_updates): ) # evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: @@ -315,6 +324,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): ) actor.train() if logger is not None: + log_info.update(timeit.todict(prefix="time")) log_metrics(logger, log_info, i) pbar.close() diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 17153a59913..2fd990d0bb3 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -159,7 +159,10 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.collector.max_frames_per_traj collector_iter = iter(collector) - for _ in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collection"): tensordict = next(collector_iter) current_frames = tensordict.numel() diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 1a270ee8ccc..f31ce04f561 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -133,6 +133,8 @@ def update(data): # Training loop for i in pbar: + timeit.printevery(1000, cfg.optim.gradient_steps, erase=True) + # sample data with timeit("sample"): data = replay_buffer.sample() diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 4f6c765d1e8..28b35099286 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -156,7 +156,10 @@ def update(sampled_tensordict): eval_rollout_steps = cfg.collector.max_frames_per_traj collector_iter = iter(collector) pbar = tqdm.tqdm(range(collector.total_frames)) - for _ in range(len(collector)): + total_iter = len(collector) + for _ in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) + with timeit("collection"): tensordict = next(collector_iter) current_frames = tensordict.numel() diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index f5f0ae41945..8b97f227490 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -178,7 +178,6 @@ def update(batch, num_network_updates): # Update the networks optim.step() return loss.detach().set("alpha", alpha), num_network_updates.clone() - if cfg.compile.compile: update = compile_with_warmup(update, mode=compile_mode, warmup=1) adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1) @@ -203,8 +202,10 @@ def update(batch, num_network_updates): losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) collector_iter = iter(collector) + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) - for i in range(len(collector)): with timeit("collecting"): data = next(collector_iter) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 2252a68b3de..162b8e701df 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -163,7 +163,6 @@ def update(batch, num_network_updates): # Update the networks optim.step() return loss.detach().set("alpha", alpha), num_network_updates.clone() - if cfg.compile.compile: update = compile_with_warmup(update, mode=compile_mode, warmup=1) adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1) @@ -192,8 +191,10 @@ def update(batch, num_network_updates): losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) collector_iter = iter(collector) + total_iter = len(collector) + for i in range(total_iter): + timeit.printevery(1000, total_iter, erase=True) - for i in range(len(collector)): with timeit("collecting"): data = next(collector_iter) diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 7a41bf0ab8f..d4c75c85179 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -52,6 +52,7 @@ import torchrl.modules import torchrl.objectives import torchrl.trainers +from torchrl._utils import compile_with_warmup, timeit # Filter warnings in subprocesses: True by default given the multiple optional # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`. diff --git a/torchrl/_utils.py b/torchrl/_utils.py index cc1621d8723..6a2f80aeffb 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -103,7 +103,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): val[2] = N @staticmethod - def print(prefix=None) -> str: # noqa: T202 + def print(prefix: str = None) -> str: # noqa: T202 """Prints the state of the timer. Returns: @@ -123,6 +123,25 @@ def print(prefix=None) -> str: # noqa: T202 logger.info(string[-1]) return "\n".join(string) + _printevery_count = 0 + + @classmethod + def printevery( + cls, + num_prints: int, + total_count: int, + *, + prefix: str = None, + erase: bool = False, + ) -> None: + """Prints the state of the timer at regular intervals.""" + interval = max(1, total_count // num_prints) + if cls._printevery_count % interval == 0: + cls.print(prefix=prefix) + if erase: + cls.erase() + cls._printevery_count += 1 + @classmethod def todict(cls, percall=True, prefix=None): def _make_key(key):