[Feature] timeit.printevery
ghstack-source-id: 7564a8ef2c0228ce513af0607310700b2fa0b3c0
Pull Request resolved: #2653
vmoens committed Dec 15, 2024
1 parent 55cfd14 commit 00352ab
Showing 20 changed files with 95 additions and 77 deletions.
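
Every training script gets the same two-part edit: the periodic `if i % N == 0: timeit.print(); timeit.erase()` blocks are dropped, and a single `timeit.printevery(...)` call at the top of each loop takes over the reporting. Below is a minimal, self-contained sketch of the resulting loop shape; the argument semantics (roughly: report at most 1000 times over the run, erasing the accumulated timings after each report) are inferred from the call sites in the diffs, and the sleeps are stand-ins for real work.

# Minimal sketch of the new pattern, assuming the timeit utility from
# torchrl._utils as used throughout the diffs below.
import time

from torchrl._utils import timeit

total_iter = 10_000  # stands in for len(collector) in the real scripts
for i in range(total_iter):
    # Report accumulated timings at most ~1000 times over the whole run,
    # erasing them after each report (semantics inferred from usage).
    timeit.printevery(1000, total_iter, erase=True)

    with timeit("collecting"):  # accumulate wall-clock time under a key
        time.sleep(0.001)  # stand-in for data collection

    with timeit("update"):
        time.sleep(0.001)  # stand-in for the optimization step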
10 changes: 5 additions & 5 deletions sota-implementations/a2c/a2c_atari.py
@@ -182,7 +182,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
     lr = cfg.optim.lr
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+
         with timeit("collecting"):
             data = next(c_iter)
 
@@ -261,10 +264,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
                         "test/reward": test_rewards.mean(),
                     }
                 )
-            if i % 200 == 0:
-                log_info.update(timeit.todict(prefix="time"))
-                timeit.print()
-                timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         if logger:
             for key, value in log_info.items():

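The companion edit, visible above and repeated in most files below: timing stats now flow into the logging dict on every iteration through `timeit.todict(prefix="time")` rather than only every 200th. A small sketch of that call follows; the exact key format shown in the comment is an assumption based on the prefix argument.

# Sketch: timeit.todict flattens the accumulated timings into a plain
# dict ready for a logger. Key names like "time/collecting" are an
# assumption based on the prefix argument used in these scripts.
from torchrl._utils import timeit

with timeit("collecting"):
    pass  # timed work goes here

for key, value in timeit.todict(prefix="time").items():
    print(key, value)  # e.g. time/collecting 1.2e-05 (format assumed)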
10 changes: 5 additions & 5 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -179,7 +179,10 @@ def update(batch):
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+
         with timeit("collecting"):
             data = next(c_iter)
 
@@ -257,10 +260,7 @@ def update(batch):
                 )
                 actor.train()
 
-            if i % 200 == 0:
-                log_info.update(timeit.todict(prefix="time"))
-                timeit.print()
-                timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         if logger:
             for key, value in log_info.items():

9 changes: 2 additions & 7 deletions sota-implementations/cql/cql_offline.py
@@ -11,7 +11,6 @@
 """
 from __future__ import annotations
 
-import time
 import warnings
 
 import hydra
@@ -21,7+20,7 @@
 import tqdm
 from tensordict.nn import CudaGraphModule
 
-from torchrl._utils import logger as torchrl_logger, timeit
+from torchrl._utils import timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.objectives import group_optimizers
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration):
     eval_steps = cfg.logger.eval_steps
 
     # Training loop
-    start_time = time.time()
     policy_eval_start = torch.tensor(policy_eval_start, device=device)
     for i in range(gradient_steps):
+        timeit.printevery(1000, gradient_steps, erase=True)
         pbar.update(1)
         # sample data
         with timeit("sample"):
@@ -195,12 +194,8 @@ def update(data, policy_eval_start, iteration):
         if i % 200 == 0:
             to_log.update(timeit.todict(prefix="time"))
         log_metrics(logger, to_log, i)
-        if i % 200 == 0:
-            timeit.print()
-            timeit.erase()
 
     pbar.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
     if not eval_env.is_closed:
         eval_env.close()
 
7 changes: 3 additions & 4 deletions sota-implementations/cql/cql_online.py
@@ -170,7 +170,9 @@ def update(sampled_tensordict):
     eval_rollout_steps = cfg.logger.eval_steps
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             tensordict = next(c_iter)
             pbar.update(tensordict.numel())
@@ -245,9 +247,6 @@ def update(sampled_tensordict):
             metrics_to_log["eval/reward"] = eval_reward
 
         log_metrics(logger, metrics_to_log, collected_frames)
-        if i % 10 == 0:
-            timeit.print()
-            timeit.erase()
 
     collector.shutdown()
     if not eval_env.is_closed:

4 changes: 0 additions & 4 deletions sota-implementations/cql/discrete_cql_online.py
@@ -227,10 +227,6 @@ def update(sampled_tensordict):
         if i % 100 == 0:
             metrics_to_log.update(timeit.todict(prefix="time"))
 
-        if i % 100 == 0:
-            timeit.print()
-            timeit.erase()
-
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
 
7 changes: 3 additions & 4 deletions sota-implementations/crossq/crossq.py
@@ -192,7 +192,9 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
     update_counter = 0
     delayed_updates = cfg.optim.policy_update_delay
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             torch.compiler.cudagraph_mark_step_begin()
             tensordict = next(c_iter)
@@ -267,9 +269,6 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
 
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
-        if i % 20 == 0:
-            timeit.print()
-            timeit.erase()
 
     collector.shutdown()
     if not eval_env.is_closed:

9 changes: 4 additions & 5 deletions sota-implementations/ddpg/ddpg.py
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
     eval_rollout_steps = cfg.env.max_episode_steps
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             tensordict = next(c_iter)
             # Update exploration policy
@@ -226,10 +228,7 @@ def update(sampled_tensordict):
                 eval_env.apply(dump_video)
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-            if i % 20 == 0:
-                metrics_to_log.update(timeit.todict(prefix="time"))
-                timeit.print()
-                timeit.erase()
+        metrics_to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)

6 changes: 2 additions & 4 deletions sota-implementations/decision_transformer/dt.py
@@ -128,6 +128,7 @@ def update(data: TensorDict) -> TensorDict:
     # Pretraining
     pbar = tqdm.tqdm(range(pretrain_gradient_steps))
     for i in pbar:
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
         # Sample data
         with timeit("rb - sample"):
             data = offline_buffer.sample().to(model_device)
@@ -151,10 +152,7 @@ def update(data: TensorDict) -> TensorDict:
             to_log["eval/reward"] = (
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
-            if i % 200 == 0:
-                to_log.update(timeit.todict(prefix="time"))
-                timeit.print()
-                timeit.erase()
+        to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, to_log, i)

9 changes: 2 additions & 7 deletions sota-implementations/decision_transformer/online_dt.py
@@ -8,7 +8,6 @@
 """
 from __future__ import annotations
 
-import time
 import warnings
 
 import hydra
@@ -130,8 +129,8 @@ def update(data):
 
     torchrl_logger.info(" ***Pretraining*** ")
     # Pretraining
-    start_time = time.time()
     for i in range(pretrain_gradient_steps):
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
         pbar.update(1)
         with timeit("sample"):
             # Sample data
@@ -170,18 +169,14 @@ def update(data):
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
 
-        if i % 200 == 0:
-            to_log.update(timeit.todict(prefix="time"))
-            timeit.print()
-            timeit.erase()
+        to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
             log_metrics(logger, to_log, i)
 
     pbar.close()
     if not test_env.is_closed:
         test_env.close()
-    torchrl_logger.info(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":

9 changes: 4 additions & 5 deletions sota-implementations/discrete_sac/discrete_sac.py
@@ -155,7 +155,9 @@ def update(sampled_tensordict):
     frames_per_batch = cfg.collector.frames_per_batch
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             collected_data = next(c_iter)
 
@@ -229,10 +231,7 @@ def update(sampled_tensordict):
                 eval_env.apply(dump_video)
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-            if i % 50 == 0:
-                metrics_to_log.update(timeit.todict(prefix="time"))
-                timeit.print()
-                timeit.erase()
+        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
             log_metrics(logger, metrics_to_log, collected_frames)
 
9 changes: 4 additions & 5 deletions sota-implementations/dqn/dqn_atari.py
@@ -173,7 +173,9 @@ def update(sampled_tensordict):
     pbar = tqdm.tqdm(total=total_frames)
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             data = next(c_iter)
         log_info = {}
@@ -241,10 +243,7 @@ def update(sampled_tensordict):
                 )
                 model.train()
 
-        if i % 200 == 0:
-            timeit.print()
-            log_info.update(timeit.todict(prefix="time"))
-            timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         # Log all the information
         if logger:

9 changes: 4 additions & 5 deletions sota-implementations/dqn/dqn_cartpole.py
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
     q_losses = torch.zeros(num_updates, device=device)
 
     c_iter = iter(collector)
-    for i in range(len(collector)):
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             data = next(c_iter)
 
@@ -226,10 +228,7 @@ def update(sampled_tensordict):
             }
         )
 
-        if i % 200 == 0:
-            timeit.print()
-            log_info.update(timeit.todict(prefix="time"))
-            timeit.erase()
+        log_info.update(timeit.todict(prefix="time"))
 
         # Log all the information
         if logger:

30 changes: 20 additions & 10 deletions sota-implementations/gail/gail.py
@@ -22,7 +22,7 @@
 from ppo_utils import eval_model, make_env, make_ppo_models
 from tensordict.nn import CudaGraphModule
 
-from torchrl._utils import compile_with_warmup
+from torchrl._utils import compile_with_warmup, timeit
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
@@ -256,19 +256,28 @@ def update(data, expert_data, num_network_updates=num_network_updates):
     cfg_logger_test_interval = cfg.logger.test_interval
     cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
 
-    for i, data in enumerate(collector):
+    total_iter = len(collector)
+    collector_iter = iter(collector)
+    for i in range(total_iter):
+
+        timeit.printevery(1000, total_iter, erase=True)
+
+        with timeit("collection"):
+            data = next(collector_iter)
+
         log_info = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
         pbar.update(data.numel())
 
         # Update discriminator
-        # Get expert data
-        expert_data = replay_buffer.sample()
-        expert_data = expert_data.to(device)
+        with timeit("rb - sample expert"):
+            # Get expert data
+            expert_data = replay_buffer.sample()
+            expert_data = expert_data.to(device)
 
-        metadata = update(data, expert_data)
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            metadata = update(data, expert_data)
         d_loss = metadata["dloss"]
         alpha = metadata["alpha"]
 
@@ -287,8 +296,6 @@ def update(data, expert_data, num_network_updates=num_network_updates):
 
         log_info.update(
             {
-                # "train/actor_loss": actor_loss.item(),
-                # "train/critic_loss": critic_loss.item(),
                 "train/discriminator_loss": d_loss["loss"],
                 "train/lr": alpha * cfg_optim_lr,
                 "train/clip_epsilon": (
@@ -300,7 +307,9 @@ def update(data, expert_data, num_network_updates=num_network_updates):
         )
 
         # evaluation
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
                 i * frames_in_batch
             ) // cfg_logger_test_interval:
@@ -315,6 +324,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
                 )
                 actor.train()
         if logger is not None:
+            log_info.update(timeit.todict(prefix="time"))
             log_metrics(logger, log_info, i)
 
     pbar.close()

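gail.py carries the largest diff because it also trades `for i, data in enumerate(collector)` for an explicit iter()/next() loop, so that collection itself can sit inside a timeit block like the other stages. A self-contained sketch of that refactor follows, with a plain list standing in for the data collector.

# Sketch of the iter()/next() refactor: wrapping next() in a timeit
# block lets collection time be measured like any other stage. The
# list here is just a stand-in for a TorchRL DataCollector.
from torchrl._utils import timeit

collector = [list(range(100)) for _ in range(8)]  # stand-in collector

total_iter = len(collector)
collector_iter = iter(collector)
for i in range(total_iter):
    timeit.printevery(1000, total_iter, erase=True)
    with timeit("collection"):
        data = next(collector_iter)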
5 changes: 4 additions & 1 deletion sota-implementations/iql/discrete_iql.py
@@ -159,7 +159,10 @@ def update(sampled_tensordict):
     eval_rollout_steps = cfg.collector.max_frames_per_traj
 
     collector_iter = iter(collector)
-    for _ in range(len(collector)):
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+
         with timeit("collection"):
             tensordict = next(collector_iter)
             current_frames = tensordict.numel()

2 changes: 2 additions & 0 deletions sota-implementations/iql/iql_offline.py
@@ -133,6 +133,8 @@ def update(data):
 
     # Training loop
     for i in pbar:
+        timeit.printevery(1000, cfg.optim.gradient_steps, erase=True)
+
         # sample data
         with timeit("sample"):
             data = replay_buffer.sample()

5 changes: 4 additions & 1 deletion sota-implementations/iql/iql_online.py
@@ -156,7 +156,10 @@ def update(sampled_tensordict):
     eval_rollout_steps = cfg.collector.max_frames_per_traj
     collector_iter = iter(collector)
     pbar = tqdm.tqdm(range(collector.total_frames))
-    for _ in range(len(collector)):
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+
        with timeit("collection"):
             tensordict = next(collector_iter)
             current_frames = tensordict.numel()

(Diffs for the remaining 4 of the 20 changed files did not load and are not shown.)