[Feature] Log pbar rate in SOTA implementations
ghstack-source-id: 183a57a0e3630b031448c3bc87d0a1cf49ad73bc
Pull Request resolved: #2662
vmoens committed Dec 17, 2024
1 parent 91064bc commit 46a28e5
Showing 27 changed files with 151 additions and 137 deletions.
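The same pattern is applied to every training script below: the per-iteration metrics dict is renamed to metrics_to_log, the timeit.todict(prefix="time") aggregation moves under the logger guard, and the tqdm progress-bar rate is logged as time/speed. Since the online scripts advance the bar with pbar.update(data.numel()), that rate is collector frames per second. A minimal standalone sketch of the tqdm side of this (the dummy loop and sleep are illustrative; only pbar.format_dict["rate"] and the time/speed key come from this commit):

    import time
    from tqdm import tqdm

    pbar = tqdm(total=1_000)
    for _ in range(10):
        time.sleep(0.05)      # stand-in for collection + optimization
        pbar.update(100)      # advance by the number of frames collected

        metrics_to_log = {}
        # tqdm exposes its current rate estimate (units/s) in format_dict;
        # it can be None on the very first iteration, so guard before logging.
        rate = pbar.format_dict["rate"]
        if rate is not None:
            metrics_to_log["time/speed"] = rate
    pbar.close()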
15 changes: 8 additions & 7 deletions sota-implementations/a2c/a2c_atari.py
@@ -189,7 +189,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         with timeit("collecting"):
             data = next(c_iter)

-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
         pbar.update(data.numel())
@@ -198,7 +198,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "terminated"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -242,8 +242,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         losses = torch.stack(losses).float().mean()

         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": lr * alpha,
             }
@@ -259,15 +259,16 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
                 test_rewards = eval_model(
                     actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
                 )
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "test/reward": test_rewards.mean(),
                     }
                 )
-        log_info.update(timeit.todict(prefix="time"))

         if logger:
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)

     collector.shutdown()
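Assembled from the four a2c_atari.py hunks above, the end of each collector iteration now looks roughly like this (a reconstruction for readability; the elided context between the hunks is summarised as comments, and indentation is approximate):

    metrics_to_log = {}
    # ... train/reward, train/episode_length, the per-loss train/* scalars and
    # ... train/lr are accumulated into metrics_to_log as shown above ...
    # ... on evaluation iterations, test/reward is added as well ...

    if logger:
        metrics_to_log.update(timeit.todict(prefix="time"))       # timing breakdown
        metrics_to_log["time/speed"] = pbar.format_dict["rate"]   # collector frames/s
        for key, value in metrics_to_log.items():
            logger.log_scalar(key, value, collected_frames)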
18 changes: 8 additions & 10 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -186,7 +186,7 @@ def update(batch):
         with timeit("collecting"):
             data = next(c_iter)

-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
         pbar.update(data.numel())
@@ -195,7 +195,7 @@ def update(batch):
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -236,8 +236,8 @@ def update(batch):
         # Get training losses
         losses = torch.stack(losses).float().mean()
         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": alpha * cfg.optim.lr,
             }
@@ -253,21 +253,19 @@ def update(batch):
                 test_rewards = eval_model(
                     actor, test_env, num_episodes=cfg.logger.num_test_episodes
                 )
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "test/reward": test_rewards.mean(),
                     }
                 )
             actor.train()

-        log_info.update(timeit.todict(prefix="time"))

         if logger:
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)

         torch.compiler.cudagraph_mark_step_begin()

     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
9 changes: 5 additions & 4 deletions sota-implementations/cql/cql_offline.py
@@ -172,7 +172,7 @@ def update(data, policy_eval_start, iteration):
         )

         # log metrics
-        to_log = {
+        metrics_to_log = {
             "loss": loss.cpu(),
             **loss_vals.cpu(),
         }
@@ -188,11 +188,12 @@ def update(data, policy_eval_start, iteration):
                 )
                 eval_env.apply(dump_video)
                 eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-                to_log["evaluation_reward"] = eval_reward
+                metrics_to_log["evaluation_reward"] = eval_reward

         with timeit("log"):
-            to_log.update(timeit.todict(prefix="time"))
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)

     pbar.close()
     if not eval_env.is_closed:
3 changes: 2 additions & 1 deletion sota-implementations/cql/cql_online.py
@@ -220,7 +220,6 @@ def update(sampled_tensordict):
                 "loss_alpha_prime"
             ).mean()
             metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
-            metrics_to_log.update(timeit.todict(prefix="time"))

         # Evaluation
         with timeit("eval"):
@@ -241,6 +240,8 @@ def update(sampled_tensordict):
                 eval_env.apply(dump_video)
                 metrics_to_log["eval/reward"] = eval_reward

+        metrics_to_log.update(timeit.todict(prefix="time"))
+        metrics_to_log["time/speed"] = pbar.format_dict["rate"]
         log_metrics(logger, metrics_to_log, collected_frames)

     collector.shutdown()
3 changes: 2 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
@@ -222,9 +222,10 @@ def update(sampled_tensordict):
             tds = torch.stack(tds, dim=0).mean()
             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
             metrics_to_log["train/cql_loss"] = tds["loss_cql"]
-        metrics_to_log.update(timeit.todict(prefix="time"))

         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)

     collector.shutdown()
3 changes: 2 additions & 1 deletion sota-implementations/crossq/crossq.py
@@ -256,13 +256,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                 episode_length
             )
-        metrics_to_log.update(timeit.todict(prefix="time"))
         if collected_frames >= init_random_frames:
             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
             metrics_to_log["train/actor_loss"] = tds["loss_actor"]
             metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)

     collector.shutdown()
3 changes: 2 additions & 1 deletion sota-implementations/ddpg/ddpg.py
@@ -224,9 +224,10 @@ def update(sampled_tensordict):
             eval_env.apply(dump_video)
             eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
             metrics_to_log["eval/reward"] = eval_reward
-        metrics_to_log.update(timeit.todict(prefix="time"))

         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)

     collector.shutdown()
9 changes: 5 additions & 4 deletions sota-implementations/decision_transformer/dt.py
@@ -136,7 +136,7 @@ def update(data: TensorDict) -> TensorDict:
         loss_vals = update(data)
         scheduler.step()
         # Log metrics
-        to_log = {"train/loss": loss_vals["loss"]}
+        metrics_to_log = {"train/loss": loss_vals["loss"]}

         # Evaluation
         with set_exploration_type(
@@ -149,13 +149,14 @@ def update(data: TensorDict) -> TensorDict:
                     auto_cast_to_device=True,
                 )
                 test_env.apply(dump_video)
-                to_log["eval/reward"] = (
+                metrics_to_log["eval/reward"] = (
                     eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                 )
-        to_log.update(timeit.todict(prefix="time"))

         if logger is not None:
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)

     pbar.close()
     if not test_env.is_closed:
10 changes: 5 additions & 5 deletions sota-implementations/decision_transformer/online_dt.py
@@ -143,7 +143,7 @@ def update(data):
         scheduler.step()

         # Log metrics
-        to_log = {
+        metrics_to_log = {
             "train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
             "train/loss_entropy": loss_vals["loss_entropy"],
             "train/loss_alpha": loss_vals["loss_alpha"],
@@ -165,14 +165,14 @@ def update(data):
                 )
                 test_env.apply(dump_video)
                 inference_policy.train()
-                to_log["eval/reward"] = (
+                metrics_to_log["eval/reward"] = (
                     eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                 )

-        to_log.update(timeit.todict(prefix="time"))

         if logger is not None:
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)

     pbar.close()
     if not test_env.is_closed:
3 changes: 2 additions & 1 deletion sota-implementations/discrete_sac/discrete_sac.py
@@ -227,8 +227,9 @@ def update(sampled_tensordict):
             eval_env.apply(dump_video)
             eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
             metrics_to_log["eval/reward"] = eval_reward
-        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)

     collector.shutdown()
4 changes: 2 additions & 2 deletions sota-implementations/dqn/config_atari.yaml
@@ -7,7 +7,7 @@ env:
 # collector
 collector:
   total_frames: 40_000_100
-  frames_per_batch: 16
+  frames_per_batch: 1600
   eps_start: 1.0
   eps_end: 0.01
   annealing_frames: 4_000_000
@@ -38,7 +38,7 @@ optim:
 loss:
   gamma: 0.99
   hard_update_freq: 10_000
-  num_updates: 1
+  num_updates: 100

 compile:
   compile: False
4 changes: 2 additions & 2 deletions sota-implementations/dqn/config_cartpole.yaml
@@ -7,7 +7,7 @@ env:
 # collector
 collector:
   total_frames: 500_100
-  frames_per_batch: 10
+  frames_per_batch: 1000
   eps_start: 1.0
   eps_end: 0.05
   annealing_frames: 250_000
@@ -37,7 +37,7 @@ optim:
 loss:
   gamma: 0.99
   hard_update_freq: 50
-  num_updates: 1
+  num_updates: 100

 compile:
   compile: False
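In both DQN configs, frames_per_batch and num_updates are scaled by the same factor of 100, so the number of gradient updates per collected frame is unchanged; each collector batch simply becomes 100x larger (the commit message does not state the motivation, but larger batches make the progress-bar rate a steadier throughput estimate). A quick check of the ratios:

    from fractions import Fraction

    # Atari: 1 update per 16 frames before and after this commit
    assert Fraction(1, 16) == Fraction(100, 1600)
    # CartPole: 1 update per 10 frames before and after
    assert Fraction(1, 10) == Fraction(100, 1000)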