[Feature] Log pbar rate in SOTA implementations
ghstack-source-id: 11dff37f598411133c4d6e61f4c760bd5abf6a08
Pull Request resolved: #2662
vmoens committed Dec 17, 2024
1 parent 91064bc commit 1d5545f
Showing 25 changed files with 131 additions and 118 deletions.
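
The pattern is the same in every script touched by this commit: the per-iteration logging dict is renamed from log_info / to_log to metrics_to_log, and at logging time the progress bar's measured throughput is read from pbar.format_dict["rate"] and recorded under time/speed. Below is a minimal, self-contained sketch of that pattern; the fake_logger helper, the frame counts, and the time.sleep stand-in for data collection are illustrative placeholders, not part of the TorchRL scripts.

import time

from tqdm import tqdm


def fake_logger(key, value, step):
    # Hypothetical stand-in for the experiment logger used in the SOTA scripts.
    print(f"step={step} {key}={value}")


total_frames = 1_000
frames_per_batch = 100
pbar = tqdm(total=total_frames)
collected_frames = 0

for _ in range(total_frames // frames_per_batch):
    time.sleep(0.01)  # pretend to collect a batch of frames
    pbar.update(frames_per_batch)
    collected_frames += frames_per_batch

    metrics_to_log = {"train/reward": 0.0}  # placeholder metric
    # tqdm exposes its smoothed rate (frames per second here) in format_dict;
    # it can be None before enough updates have been observed, so guard it.
    rate = pbar.format_dict["rate"]
    if rate is not None:
        metrics_to_log["time/speed"] = rate
    for key, value in metrics_to_log.items():
        fake_logger(key, value, collected_frames)

pbar.close()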
15 changes: 8 additions & 7 deletions sota-implementations/a2c/a2c_atari.py
@@ -189,7 +189,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
with timeit("collecting"):
data = next(c_iter)

- log_info = {}
+ metrics_to_log = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch * frame_skip
pbar.update(data.numel())
@@ -198,7 +198,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "terminated"]]
- log_info.update(
+ metrics_to_log.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
@@ -242,8 +242,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
losses = torch.stack(losses).float().mean()

for key, value in losses.items():
- log_info.update({f"train/{key}": value.item()})
- log_info.update(
+ metrics_to_log.update({f"train/{key}": value.item()})
+ metrics_to_log.update(
{
"train/lr": lr * alpha,
}
@@ -259,15 +259,16 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
test_rewards = eval_model(
actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
)
- log_info.update(
+ metrics_to_log.update(
{
"test/reward": test_rewards.mean(),
}
)
- log_info.update(timeit.todict(prefix="time"))

if logger:
- for key, value in log_info.items():
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ for key, value in metrics_to_log.items():
logger.log_scalar(key, value, collected_frames)

collector.shutdown()
18 changes: 8 additions & 10 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -186,7 +186,7 @@ def update(batch):
with timeit("collecting"):
data = next(c_iter)

- log_info = {}
+ metrics_to_log = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())
@@ -195,7 +195,7 @@ def update(batch):
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
- log_info.update(
+ metrics_to_log.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
@@ -236,8 +236,8 @@ def update(batch):
# Get training losses
losses = torch.stack(losses).float().mean()
for key, value in losses.items():
- log_info.update({f"train/{key}": value.item()})
- log_info.update(
+ metrics_to_log.update({f"train/{key}": value.item()})
+ metrics_to_log.update(
{
"train/lr": alpha * cfg.optim.lr,
}
@@ -253,21 +253,19 @@ def update(batch):
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
)
- log_info.update(
+ metrics_to_log.update(
{
"test/reward": test_rewards.mean(),
}
)
actor.train()

- log_info.update(timeit.todict(prefix="time"))

if logger:
- for key, value in log_info.items():
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ for key, value in metrics_to_log.items():
logger.log_scalar(key, value, collected_frames)

torch.compiler.cudagraph_mark_step_begin()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
9 changes: 5 additions & 4 deletions sota-implementations/cql/cql_offline.py
@@ -172,7 +172,7 @@ def update(data, policy_eval_start, iteration):
)

# log metrics
- to_log = {
+ metrics_to_log = {
"loss": loss.cpu(),
**loss_vals.cpu(),
}
@@ -188,11 +188,12 @@ def update(data, policy_eval_start, iteration):
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
metrics_to_log["evaluation_reward"] = eval_reward

with timeit("log"):
- to_log.update(timeit.todict(prefix="time"))
- log_metrics(logger, to_log, i)
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ log_metrics(logger, metrics_to_log, i)

pbar.close()
if not eval_env.is_closed:
3 changes: 2 additions & 1 deletion sota-implementations/cql/cql_online.py
@@ -220,7 +220,6 @@ def update(sampled_tensordict):
"loss_alpha_prime"
).mean()
metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
- metrics_to_log.update(timeit.todict(prefix="time"))

# Evaluation
with timeit("eval"):
@@ -241,6 +240,8 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
metrics_to_log["eval/reward"] = eval_reward

+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
3 changes: 2 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
@@ -222,9 +222,10 @@ def update(sampled_tensordict):
tds = torch.stack(tds, dim=0).mean()
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/cql_loss"] = tds["loss_cql"]
- metrics_to_log.update(timeit.todict(prefix="time"))

if logger is not None:
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
3 changes: 2 additions & 1 deletion sota-implementations/crossq/crossq.py
@@ -256,13 +256,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
- metrics_to_log.update(timeit.todict(prefix="time"))
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/actor_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

if logger is not None:
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
3 changes: 2 additions & 1 deletion sota-implementations/ddpg/ddpg.py
@@ -224,9 +224,10 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
- metrics_to_log.update(timeit.todict(prefix="time"))

if logger is not None:
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
9 changes: 5 additions & 4 deletions sota-implementations/decision_transformer/dt.py
@@ -136,7 +136,7 @@ def update(data: TensorDict) -> TensorDict:
loss_vals = update(data)
scheduler.step()
# Log metrics
to_log = {"train/loss": loss_vals["loss"]}
metrics_to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(
@@ -149,13 +149,14 @@ def update(data: TensorDict) -> TensorDict:
auto_cast_to_device=True,
)
test_env.apply(dump_video)
to_log["eval/reward"] = (
metrics_to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
- to_log.update(timeit.todict(prefix="time"))

if logger is not None:
- log_metrics(logger, to_log, i)
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ log_metrics(logger, metrics_to_log, i)

pbar.close()
if not test_env.is_closed:
10 changes: 5 additions & 5 deletions sota-implementations/decision_transformer/online_dt.py
@@ -143,7 +143,7 @@ def update(data):
scheduler.step()

# Log metrics
- to_log = {
+ metrics_to_log = {
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
"train/loss_entropy": loss_vals["loss_entropy"],
"train/loss_alpha": loss_vals["loss_alpha"],
@@ -165,14 +165,14 @@ def update(data):
)
test_env.apply(dump_video)
inference_policy.train()
to_log["eval/reward"] = (
metrics_to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

- to_log.update(timeit.todict(prefix="time"))

if logger is not None:
- log_metrics(logger, to_log, i)
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ log_metrics(logger, metrics_to_log, i)

pbar.close()
if not test_env.is_closed:
3 changes: 2 additions & 1 deletion sota-implementations/discrete_sac/discrete_sac.py
@@ -227,8 +227,9 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
- metrics_to_log.update(timeit.todict(prefix="time"))
if logger is not None:
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
18 changes: 9 additions & 9 deletions sota-implementations/dqn/dqn_atari.py
@@ -86,7 +86,7 @@ def main(cfg: "DictConfig"): # noqa: F821
delay_value=True,
)
loss_module.set_keys(done="end-of-life", terminated="end-of-life")
- loss_module.make_value_estimator(gamma=cfg.loss.gamma)
+ loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
target_net_updater = HardUpdate(
loss_module, value_network_update_interval=cfg.loss.hard_update_freq
)
@@ -178,7 +178,7 @@ def update(sampled_tensordict):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
data = next(c_iter)
- log_info = {}
+ metrics_to_log = {}
pbar.update(data.numel())
data = data.reshape(-1)
current_frames = data.numel() * frame_skip
@@ -193,7 +193,7 @@ def update(sampled_tensordict):
episode_reward_mean = episode_rewards.mean().item()
episode_length = data["next", "step_count"][data["next", "done"]]
episode_length_mean = episode_length.sum().item() / len(episode_length)
- log_info.update(
+ metrics_to_log.update(
{
"train/episode_reward": episode_reward_mean,
"train/episode_length": episode_length_mean,
@@ -202,7 +202,7 @@ def update(sampled_tensordict):

if collected_frames < init_random_frames:
if logger:
- for key, value in log_info.items():
+ for key, value in metrics_to_log.items():
logger.log_scalar(key, value, step=collected_frames)
continue

@@ -216,7 +216,7 @@ def update(sampled_tensordict):
q_losses[j].copy_(q_loss)

# Get and log q-values, loss, epsilon, sampling time and training time
- log_info.update(
+ metrics_to_log.update(
{
"train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
"train/q_loss": q_losses.mean(),
@@ -236,18 +236,18 @@ def update(sampled_tensordict):
test_rewards = eval_model(
model, test_env, num_episodes=num_test_episodes
)
- log_info.update(
+ metrics_to_log.update(
{
"eval/reward": test_rewards,
}
)
model.train()

- log_info.update(timeit.todict(prefix="time"))

# Log all the information
if logger:
- for key, value in log_info.items():
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ for key, value in metrics_to_log.items():
logger.log_scalar(key, value, step=collected_frames)

# update weights of the inference policy
18 changes: 9 additions & 9 deletions sota-implementations/dqn/dqn_cartpole.py
@@ -69,7 +69,7 @@ def main(cfg: "DictConfig"): # noqa: F821
loss_function="l2",
delay_value=True,
)
- loss_module.make_value_estimator(gamma=cfg.loss.gamma)
+ loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
loss_module = loss_module.to(device)
target_net_updater = HardUpdate(
loss_module, value_network_update_interval=cfg.loss.hard_update_freq
@@ -162,7 +162,7 @@ def update(sampled_tensordict):
with timeit("collecting"):
data = next(c_iter)

- log_info = {}
+ metrics_to_log = {}
pbar.update(data.numel())
data = data.reshape(-1)
current_frames = data.numel()
@@ -178,7 +178,7 @@ def update(sampled_tensordict):
episode_reward_mean = episode_rewards.mean().item()
episode_length = data["next", "step_count"][data["next", "done"]]
episode_length_mean = episode_length.sum().item() / len(episode_length)
- log_info.update(
+ metrics_to_log.update(
{
"train/episode_reward": episode_reward_mean,
"train/episode_length": episode_length_mean,
@@ -188,7 +188,7 @@ def update(sampled_tensordict):
if collected_frames < init_random_frames:
if collected_frames < init_random_frames:
if logger:
- for key, value in log_info.items():
+ for key, value in metrics_to_log.items():
logger.log_scalar(key, value, step=collected_frames)
continue

@@ -202,7 +202,7 @@ def update(sampled_tensordict):
q_losses[j].copy_(q_loss)

# Get and log q-values, loss, epsilon, sampling time and training time
- log_info.update(
+ metrics_to_log.update(
{
"train/q_values": (data["action_value"] * data["action"]).sum().item()
/ frames_per_batch,
@@ -222,17 +222,17 @@ def update(sampled_tensordict):
model.eval()
test_rewards = eval_model(model, test_env, num_test_episodes)
model.train()
- log_info.update(
+ metrics_to_log.update(
{
"eval/reward": test_rewards,
}
)

- log_info.update(timeit.todict(prefix="time"))

# Log all the information
if logger:
- for key, value in log_info.items():
+ metrics_to_log.update(timeit.todict(prefix="time"))
+ metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+ for key, value in metrics_to_log.items():
logger.log_scalar(key, value, step=collected_frames)

# update weights of the inference policy