diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 47e43125ea4..3279d6e0a2b 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -189,7 +189,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         with timeit("collecting"):
             data = next(c_iter)
 
-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
         pbar.update(data.numel())
@@ -198,7 +198,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "terminated"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -242,8 +242,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         losses = torch.stack(losses).float().mean()
         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": lr * alpha,
             }
         )
@@ -259,15 +259,16 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
             test_rewards = eval_model(
                 actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
             )
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "test/reward": test_rewards.mean(),
                 }
             )
 
-        log_info.update(timeit.todict(prefix="time"))
         if logger:
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 07ad5197954..41e05dc1326 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -186,7 +186,7 @@ def update(batch):
         with timeit("collecting"):
             data = next(c_iter)
 
-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
         pbar.update(data.numel())
@@ -195,7 +195,7 @@ def update(batch):
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -236,8 +236,8 @@ def update(batch):
         # Get training losses
         losses = torch.stack(losses).float().mean()
         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": alpha * cfg.optim.lr,
             }
         )
@@ -253,21 +253,19 @@ def update(batch):
             test_rewards = eval_model(
                 actor, test_env, num_episodes=cfg.logger.num_test_episodes
             )
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "test/reward": test_rewards.mean(),
                 }
             )
             actor.train()
 
-        log_info.update(timeit.todict(prefix="time"))
-
         if logger:
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
-        torch.compiler.cudagraph_mark_step_begin()
-
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index c0030a1e9cc..2e1a20ad7a2 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -172,7 +172,7 @@ def update(data, policy_eval_start, iteration):
         )
 
         # log metrics
-        to_log = {
+        metrics_to_log = {
             "loss": loss.cpu(),
            **loss_vals.cpu(),
         }
@@ -188,11 +188,12 @@ def update(data, policy_eval_start, iteration):
             )
             eval_env.apply(dump_video)
             eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-            to_log["evaluation_reward"] = eval_reward
+            metrics_to_log["evaluation_reward"] = eval_reward
 
         with timeit("log"):
-            to_log.update(timeit.todict(prefix="time"))
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
 
     pbar.close()
     if not eval_env.is_closed:
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index 03bdf6a493f..e992bdb5939 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -220,7 +220,6 @@ def update(sampled_tensordict):
                 "loss_alpha_prime"
             ).mean()
             metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
-        metrics_to_log.update(timeit.todict(prefix="time"))
 
         # Evaluation
         with timeit("eval"):
@@ -241,6 +240,8 @@ def update(sampled_tensordict):
                 eval_env.apply(dump_video)
                 metrics_to_log["eval/reward"] = eval_reward
 
+        metrics_to_log.update(timeit.todict(prefix="time"))
+        metrics_to_log["time/speed"] = pbar.format_dict["rate"]
         log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py
index 35238c5c6ab..efb54ea3f73 100644
--- a/sota-implementations/cql/discrete_cql_online.py
+++ b/sota-implementations/cql/discrete_cql_online.py
@@ -222,9 +222,10 @@ def update(sampled_tensordict):
             tds = torch.stack(tds, dim=0).mean()
             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
             metrics_to_log["train/cql_loss"] = tds["loss_cql"]
-        metrics_to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py
index 07de3e26175..d84613e6876 100644
--- a/sota-implementations/crossq/crossq.py
+++ b/sota-implementations/crossq/crossq.py
@@ -256,13 +256,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                 episode_length
             )
-        metrics_to_log.update(timeit.todict(prefix="time"))
         if collected_frames >= init_random_frames:
             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
             metrics_to_log["train/actor_loss"] = tds["loss_actor"]
             metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
 
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py
index 6e2a749c3f1..bcb7ee6ef54 100644
--- a/sota-implementations/ddpg/ddpg.py
+++ b/sota-implementations/ddpg/ddpg.py
@@ -224,9 +224,10 @@ def update(sampled_tensordict):
                 eval_env.apply(dump_video)
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-        metrics_to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py
index 57ba327b935..8a9eb0c0985 100644
--- a/sota-implementations/decision_transformer/dt.py
+++ b/sota-implementations/decision_transformer/dt.py
@@ -136,7 +136,7 @@ def update(data: TensorDict) -> TensorDict:
         loss_vals = update(data)
         scheduler.step()
         # Log metrics
-        to_log = {"train/loss": loss_vals["loss"]}
+        metrics_to_log = {"train/loss": loss_vals["loss"]}
 
         # Evaluation
         with set_exploration_type(
@@ -149,13 +149,14 @@ def update(data: TensorDict) -> TensorDict:
                 auto_cast_to_device=True,
             )
             test_env.apply(dump_video)
-            to_log["eval/reward"] = (
+            metrics_to_log["eval/reward"] = (
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
 
-        to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
 
     pbar.close()
     if not test_env.is_closed:
diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py
index 7c6c9968774..1404cb7ebc0 100644
--- a/sota-implementations/decision_transformer/online_dt.py
+++ b/sota-implementations/decision_transformer/online_dt.py
@@ -143,7 +143,7 @@ def update(data):
         scheduler.step()
 
         # Log metrics
-        to_log = {
+        metrics_to_log = {
             "train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
             "train/loss_entropy": loss_vals["loss_entropy"],
             "train/loss_alpha": loss_vals["loss_alpha"],
@@ -165,14 +165,14 @@ def update(data):
             )
             test_env.apply(dump_video)
             inference_policy.train()
-            to_log["eval/reward"] = (
+            metrics_to_log["eval/reward"] = (
                 eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
             )
 
-        to_log.update(timeit.todict(prefix="time"))
-
         if logger is not None:
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
 
     pbar.close()
     if not test_env.is_closed:
diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py
index b7910c4e578..9ff50902887 100644
--- a/sota-implementations/discrete_sac/discrete_sac.py
+++ b/sota-implementations/discrete_sac/discrete_sac.py
@@ -227,8 +227,9 @@ def update(sampled_tensordict):
                 eval_env.apply(dump_video)
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
-        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index b4236c3e89f..f6bcf3044cb 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -86,7 +86,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         delay_value=True,
     )
     loss_module.set_keys(done="end-of-life", terminated="end-of-life")
-    loss_module.make_value_estimator(gamma=cfg.loss.gamma)
+    loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
     target_net_updater = HardUpdate(
         loss_module, value_network_update_interval=cfg.loss.hard_update_freq
     )
@@ -178,7 +178,7 @@ def update(sampled_tensordict):
         timeit.printevery(1000, total_iter, erase=True)
         with timeit("collecting"):
             data = next(c_iter)
-        log_info = {}
+        metrics_to_log = {}
         pbar.update(data.numel())
         data = data.reshape(-1)
         current_frames = data.numel() * frame_skip
@@ -193,7 +193,7 @@ def update(sampled_tensordict):
             episode_reward_mean = episode_rewards.mean().item()
            episode_length = data["next", "step_count"][data["next", "done"]]
            episode_length_mean = episode_length.sum().item() / len(episode_length)
-            log_info.update(
+            metrics_to_log.update(
                {
                    "train/episode_reward": episode_reward_mean,
                    "train/episode_length": episode_length_mean,
@@ -202,7 +202,7 @@ def update(sampled_tensordict):
 
         if collected_frames < init_random_frames:
             if logger:
-                for key, value in log_info.items():
+                for key, value in metrics_to_log.items():
                     logger.log_scalar(key, value, step=collected_frames)
             continue
 
@@ -216,7 +216,7 @@ def update(sampled_tensordict):
                 q_losses[j].copy_(q_loss)
 
         # Get and log q-values, loss, epsilon, sampling time and training time
-        log_info.update(
+        metrics_to_log.update(
             {
                 "train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
                 "train/q_loss": q_losses.mean(),
@@ -236,18 +236,18 @@ def update(sampled_tensordict):
                 test_rewards = eval_model(
                     model, test_env, num_episodes=num_test_episodes
                 )
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "eval/reward": test_rewards,
                     }
                 )
                 model.train()
 
-        log_info.update(timeit.todict(prefix="time"))
-
         # Log all the information
         if logger:
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, step=collected_frames)
 
         # update weights of the inference policy
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index 57236337ced..69689dd4c92 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -69,7 +69,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         loss_function="l2",
         delay_value=True,
     )
-    loss_module.make_value_estimator(gamma=cfg.loss.gamma)
+    loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
     loss_module = loss_module.to(device)
     target_net_updater = HardUpdate(
         loss_module, value_network_update_interval=cfg.loss.hard_update_freq
@@ -162,7 +162,7 @@ def update(sampled_tensordict):
 
         with timeit("collecting"):
             data = next(c_iter)
-        log_info = {}
+        metrics_to_log = {}
         pbar.update(data.numel())
         data = data.reshape(-1)
         current_frames = data.numel()
@@ -178,7 +178,7 @@ def update(sampled_tensordict):
             episode_reward_mean = episode_rewards.mean().item()
             episode_length = data["next", "step_count"][data["next", "done"]]
             episode_length_mean = episode_length.sum().item() / len(episode_length)
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/episode_reward": episode_reward_mean,
                     "train/episode_length": episode_length_mean,
@@ -188,7 +188,7 @@ def update(sampled_tensordict):
 
         if collected_frames < init_random_frames:
             if logger:
-                for key, value in log_info.items():
+                for key, value in metrics_to_log.items():
                     logger.log_scalar(key, value, step=collected_frames)
             continue
 
@@ -202,7 +202,7 @@ def update(sampled_tensordict):
                 q_losses[j].copy_(q_loss)
 
         # Get and log q-values, loss, epsilon, sampling time and training time
-        log_info.update(
+        metrics_to_log.update(
             {
                 "train/q_values": (data["action_value"] * data["action"]).sum().item()
                 / frames_per_batch,
@@ -222,17 +222,17 @@ def update(sampled_tensordict):
                 model.eval()
                 test_rewards = eval_model(model, test_env, num_test_episodes)
                 model.train()
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "eval/reward": test_rewards,
                     }
                 )
 
-        log_info.update(timeit.todict(prefix="time"))
-
         # Log all the information
         if logger:
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, step=collected_frames)
 
         # update weights of the inference policy
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
index 45d3acbb85f..bdb8843aaf6 100644
--- a/sota-implementations/gail/gail.py
+++ b/sota-implementations/gail/gail.py
@@ -265,7 +265,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
         with timeit("collection"):
             data = next(collector_iter)
 
-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
         pbar.update(data.numel())
@@ -286,7 +286,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
 
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -294,7 +294,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
                 }
             )
 
-        log_info.update(
+        metrics_to_log.update(
             {
                 "train/discriminator_loss": d_loss["loss"],
                 "train/lr": alpha * cfg_optim_lr,
@@ -317,15 +317,16 @@ def update(data, expert_data, num_network_updates=num_network_updates):
             test_rewards = eval_model(
                 actor, test_env, num_episodes=cfg_logger_num_test_episodes
             )
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "eval/reward": test_rewards.mean(),
                 }
             )
             actor.train()
 
         if logger is not None:
-            log_info.update(timeit.todict(prefix="time"))
-            log_metrics(logger, log_info, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
 
     pbar.close()
diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py
index b2b724f6a6d..dcf908c2cd2 100644
--- a/sota-implementations/impala/impala_multi_node_ray.py
+++ b/sota-implementations/impala/impala_multi_node_ray.py
@@ -165,7 +165,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     start_time = sampling_start = time.time()
     for i, data in enumerate(collector):
 
-        log_info = {}
+        metrics_to_log = {}
         sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
@@ -175,7 +175,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "terminated"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -186,7 +186,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         if len(accumulator) < batch_size:
             accumulator.append(data)
             if logger:
-                for key, value in log_info.items():
+                for key, value in metrics_to_log.items():
                     logger.log_scalar(key, value, collected_frames)
             continue
@@ -243,8 +243,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         training_time = time.time() - training_start
         losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": alpha * lr,
                 "train/sampling_time": sampling_time,
@@ -263,7 +263,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     actor, test_env, num_episodes=num_test_episodes
                 )
                 eval_time = time.time() - eval_start
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "eval/reward": test_reward,
                         "eval/time": eval_time,
@@ -272,7 +272,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 actor.train()
 
         if logger:
-            for key, value in log_info.items():
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py
index 07d38604391..4d90e9053bd 100644
--- a/sota-implementations/impala/impala_multi_node_submitit.py
+++ b/sota-implementations/impala/impala_multi_node_submitit.py
@@ -157,7 +157,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     start_time = sampling_start = time.time()
     for i, data in enumerate(collector):
 
-        log_info = {}
+        metrics_to_log = {}
         sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
@@ -167,7 +167,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -178,7 +178,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         if len(accumulator) < batch_size:
             accumulator.append(data)
             if logger:
-                for key, value in log_info.items():
+                for key, value in metrics_to_log.items():
                     logger.log_scalar(key, value, collected_frames)
             continue
@@ -235,8 +235,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         training_time = time.time() - training_start
         losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": alpha * lr,
                 "train/sampling_time": sampling_time,
@@ -255,7 +255,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     actor, test_env, num_episodes=num_test_episodes
                 )
                 eval_time = time.time() - eval_start
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "eval/reward": test_reward,
                         "eval/time": eval_time,
@@ -264,7 +264,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 actor.train()
 
         if logger:
-            for key, value in log_info.items():
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py
index cd11ae467c3..cda63ac0919 100644
--- a/sota-implementations/impala/impala_single_node.py
+++ b/sota-implementations/impala/impala_single_node.py
@@ -134,7 +134,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     start_time = sampling_start = time.time()
     for i, data in enumerate(collector):
 
-        log_info = {}
+        metrics_to_log = {}
         sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
@@ -144,7 +144,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "terminated"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -155,7 +155,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         if len(accumulator) < batch_size:
             accumulator.append(data)
             if logger:
-                for key, value in log_info.items():
+                for key, value in metrics_to_log.items():
                     logger.log_scalar(key, value, collected_frames)
             continue
@@ -212,8 +212,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         training_time = time.time() - training_start
         losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": alpha * lr,
                 "train/sampling_time": sampling_time,
@@ -232,7 +232,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     actor, test_env, num_episodes=num_test_episodes
                 )
                 eval_time = time.time() - eval_start
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "eval/reward": test_reward,
                         "eval/time": eval_time,
@@ -241,7 +241,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 actor.train()
 
         if logger:
-            for key, value in log_info.items():
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py
index e56661acf0c..aa4cea04024 100644
--- a/sota-implementations/iql/discrete_iql.py
+++ b/sota-implementations/iql/discrete_iql.py
@@ -226,8 +226,9 @@ def update(sampled_tensordict):
             metrics_to_log["train/q_loss"] = metadata["q_loss"]
             metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
             metrics_to_log["train/value_loss"] = metadata["value_loss"]
-        metrics_to_log.update(timeit.todict(prefix="time"))
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py
index 00f4cb24f5a..eaf791438cc 100644
--- a/sota-implementations/iql/iql_offline.py
+++ b/sota-implementations/iql/iql_offline.py
@@ -145,7 +145,7 @@ def update(data):
         loss_info = update(data)
 
         # evaluation
-        to_log = loss_info.to_dict()
+        metrics_to_log = loss_info.to_dict()
         if i % evaluation_interval == 0:
             with set_exploration_type(
                 ExplorationType.DETERMINISTIC
@@ -155,10 +155,11 @@ def update(data):
             )
             eval_env.apply(dump_video)
             eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-            to_log["evaluation_reward"] = eval_reward
+            metrics_to_log["evaluation_reward"] = eval_reward
 
         if logger is not None:
-            to_log.update(timeit.todict(prefix="time"))
-            log_metrics(logger, to_log, i)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
 
     pbar.close()
     if not eval_env.is_closed:
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index 7ec2a30dfd9..5b90f00c467 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -215,9 +215,10 @@ def update(sampled_tensordict):
             metrics_to_log["train/actor_loss"] = loss_info["loss_actor"]
             metrics_to_log["train/value_loss"] = loss_info["loss_value"]
             metrics_to_log["train/entropy"] = loss_info.get("entropy")
-        metrics_to_log.update(timeit.todict(prefix="time"))
 
         if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/multiagent/utils/logging.py b/sota-implementations/multiagent/utils/logging.py
index e19ae8d78f7..40c9b70d578 100644
--- a/sota-implementations/multiagent/utils/logging.py
+++ b/sota-implementations/multiagent/utils/logging.py
@@ -56,13 +56,13 @@ def log_training(
         .unsqueeze(-1),
     )
 
-    to_log = {
+    metrics_to_log = {
         f"train/learner/{key}": value.mean().item()
         for key, value in training_td.items()
     }
 
     if "info" in sampling_td.get("agents").keys():
-        to_log.update(
+        metrics_to_log.update(
             {
                 f"train/info/{key}": value.mean().item()
                 for key, value in sampling_td.get(("agents", "info")).items()
@@ -76,7 +76,7 @@ def log_training(
     episode_reward = sampling_td.get(("next", "agents", "episode_reward")).mean(-2)[
         done
     ]
-    to_log.update(
+    metrics_to_log.update(
         {
             "train/reward/reward_min": reward.min().item(),
             "train/reward/reward_mean": reward.mean().item(),
@@ -94,12 +94,12 @@ def log_training(
         }
     )
     if isinstance(logger, WandbLogger):
-        logger.experiment.log(to_log, commit=False)
+        logger.experiment.log(metrics_to_log, commit=False)
     else:
-        for key, value in to_log.items():
+        for key, value in metrics_to_log.items():
             logger.log_scalar(key.replace("/", "_"), value, step=step)
 
-    return to_log
+    return metrics_to_log
 
 
 def log_evaluation(
@@ -121,7 +121,7 @@ def log_evaluation(
             rollouts[k] = r[: done_index + 1]
 
     rewards = [td.get(("next", "agents", "reward")).sum(0).mean() for td in rollouts]
-    to_log = {
+    metrics_to_log = {
         "eval/episode_reward_min": min(rewards),
         "eval/episode_reward_max": max(rewards),
         "eval/episode_reward_mean": sum(rewards) / len(rollouts),
@@ -138,7 +138,7 @@ def log_evaluation(
     if isinstance(logger, WandbLogger):
         import wandb
 
-        logger.experiment.log(to_log, commit=False)
+        logger.experiment.log(metrics_to_log, commit=False)
         logger.experiment.log(
             {
                 "eval/video": wandb.Video(vid, fps=1 / env_test.world.dt, format="mp4"),
@@ -146,6 +146,6 @@ def log_evaluation(
             },
             commit=False,
         )
     else:
-        for key, value in to_log.items():
+        for key, value in metrics_to_log.items():
             logger.log_scalar(key.replace("/", "_"), value, step=step)
         logger.log_video("eval_video", vid, step=step)
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 153a1ad9515..f5adba05d7b 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -67,12 +67,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"),
+        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
         device=device,
-        storing_device=device,
         max_frames_per_traj=-1,
         compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
         cudagraph_policy=cfg.compile.cudagraphs,
@@ -214,7 +213,7 @@ def update(batch, num_network_updates):
         with timeit("collecting"):
             data = next(collector_iter)
 
-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
         pbar.update(frames_in_batch)
@@ -223,7 +222,7 @@ def update(batch, num_network_updates):
         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "terminated"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -259,8 +258,8 @@ def update(batch, num_network_updates):
         # Get training losses and times
         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses_mean.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": loss["alpha"] * cfg_optim_lr,
                 "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon,
@@ -278,15 +277,16 @@ def update(batch, num_network_updates):
             test_rewards = eval_model(
                 actor, test_env, num_episodes=cfg_logger_num_test_episodes
             )
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "eval/reward": test_rewards.mean(),
                 }
             )
             actor.train()
         if logger:
-            log_info.update(timeit.todict(prefix="time"))
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index f8568be56e6..f63ef6022fe 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -72,7 +72,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
         device=device,
-        storing_device=device,
         max_frames_per_traj=-1,
         compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
         cudagraph_policy=cfg.compile.cudagraphs,
@@ -207,7 +206,7 @@ def update(batch, num_network_updates):
         with timeit("collecting"):
             data = next(collector_iter)
 
-        log_info = {}
+        metrics_to_log = {}
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
         pbar.update(frames_in_batch)
@@ -216,7 +215,7 @@ def update(batch, num_network_updates):
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
             episode_length = data["next", "step_count"][data["next", "done"]]
-            log_info.update(
+            metrics_to_log.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
@@ -253,8 +252,8 @@ def update(batch, num_network_updates):
         # Get training losses and times
         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
         for key, value in losses_mean.items():
-            log_info.update({f"train/{key}": value.item()})
-        log_info.update(
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
             {
                 "train/lr": loss["alpha"] * cfg_optim_lr,
                 "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon
@@ -274,7 +273,7 @@ def update(batch, num_network_updates):
                 test_rewards = eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
-                log_info.update(
+                metrics_to_log.update(
                     {
                         "eval/reward": test_rewards.mean(),
                     }
@@ -282,8 +281,9 @@ def update(batch, num_network_updates):
                 )
             actor.train()
 
         if logger:
-            log_info.update(timeit.todict(prefix="time"))
-            for key, value in log_info.items():
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
                 logger.log_scalar(key, value, collected_frames)
 
         collector.update_policy_weights_()
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index a1ec631fe39..e159824f9cd 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -230,6 +230,7 @@ def update(sampled_tensordict):
                 metrics_to_log["eval/reward"] = eval_reward
         if logger is not None:
             metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index bcbe6b879da..3a741735a1c 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -247,6 +247,7 @@ def update(sampled_tensordict, update_actor, prb=prb):
             metrics_to_log["eval/reward"] = eval_reward
         if logger is not None:
             metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
             log_metrics(logger, metrics_to_log, collected_frames)
 
     collector.shutdown()
diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py
index 35563777962..ac65f2875cf 100644
--- a/sota-implementations/td3_bc/td3_bc.py
+++ b/sota-implementations/td3_bc/td3_bc.py
@@ -151,11 +151,11 @@ def update(sampled_tensordict, update_actor):
         torch.compiler.cudagraph_mark_step_begin()
         metadata = update(sampled_tensordict, update_actor).clone()
 
-        to_log = {}
+        metrics_to_log = {}
         if update_actor:
-            to_log.update(metadata.to_dict())
+            metrics_to_log.update(metadata.to_dict())
         else:
-            to_log.update(metadata.exclude("actor_loss").to_dict())
+            metrics_to_log.update(metadata.exclude("actor_loss").to_dict())
 
         # evaluation
         if update_counter % evaluation_interval == 0:
@@ -167,10 +167,11 @@ def update(sampled_tensordict, update_actor):
             )
             eval_env.apply(dump_video)
             eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-            to_log["evaluation_reward"] = eval_reward
+            metrics_to_log["evaluation_reward"] = eval_reward
         if logger is not None:
-            to_log.update(timeit.todict(prefix="time"))
-            log_metrics(logger, to_log, update_counter)
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, update_counter)
 
     if not eval_env.is_closed:
         eval_env.close()
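Taken together, these hunks converge on a single per-iteration logging pattern: scalars accumulate in a `metrics_to_log` dict, and only when a logger is configured do the scripts fold in the `timeit.todict(prefix="time")` breakdown plus the tqdm throughput as `time/speed` before emitting everything under one step index. The sketch below restates that pattern outside any one script. The helper name `flush_metrics` and the stub `ConsoleLogger` are illustrative and not part of the diff; `pbar.format_dict["rate"]` and `logger.log_scalar(...)` are the calls the scripts themselves use, and `timings` stands in for the dict that `timeit.todict(prefix="time")` returns in these files.

```python
# Hypothetical helper mirroring the logging pattern above; not part of the diff.
from typing import Mapping, Optional

from tqdm import tqdm


class ConsoleLogger:
    """Stand-in for a logger object exposing log_scalar(name, value, step)."""

    def log_scalar(self, name: str, value, step: int) -> None:
        print(f"[{step}] {name}: {value}")


def flush_metrics(
    metrics_to_log: dict,
    logger: Optional[ConsoleLogger],
    pbar: tqdm,
    step: int,
    timings: Optional[Mapping[str, float]] = None,
) -> None:
    """Emit accumulated metrics once per collector iteration.

    In the scripts above, `timings` is `timeit.todict(prefix="time")` and
    `step` is `collected_frames` (or the iteration index for offline runs).
    """
    if logger is None:
        # Timing and speed entries are only materialized when a logger exists,
        # matching the hunks that moved timeit.todict under the logger check.
        return
    if timings:
        metrics_to_log.update(timings)
    # tqdm exposes its smoothed it/s estimate via format_dict["rate"].
    metrics_to_log["time/speed"] = pbar.format_dict["rate"]
    for key, value in metrics_to_log.items():
        logger.log_scalar(key, value, step)


if __name__ == "__main__":
    pbar = tqdm(total=1000)
    pbar.update(100)
    flush_metrics(
        {"train/reward": 1.0},
        ConsoleLogger(),
        pbar,
        step=100,
        timings={"time/collecting": 0.5},
    )
    pbar.close()
```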