Commit: amend
matteobettini committed Oct 3, 2024
1 parent 30940f9 commit 39df9cc
Showing 1 changed file with 88 additions and 46 deletions.
134 changes: 88 additions & 46 deletions benchmarl/experiment/logger.py
@@ -6,7 +6,9 @@

import json
import os
import warnings
from pathlib import Path

from typing import Dict, List, Optional

import numpy as np
@@ -93,26 +95,12 @@ def log_collection(
to_log = {}
json_metrics = {}
for group in self.group_map.keys():
episode_reward = self._get_episode_reward(group, batch)
done = self._get_done(group, batch)
reward = self._get_reward(group, batch)
to_log.update(
{
f"collection/{group}/reward/reward_min": reward.min().item(),
f"collection/{group}/reward/reward_mean": reward.mean().item(),
f"collection/{group}/reward/reward_max": reward.max().item(),
}
done = self._get_global_done(group, batch) # Does not have agent dim
group_episode_rewards = self._log_individual_and_group_rewards(
group, batch, done, to_log, prefix="collection"
)
json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
episode_reward = episode_reward[done]
if episode_reward.numel() > 0:
to_log.update(
{
f"collection/{group}/reward/episode_reward_min": episode_reward.min().item(),
f"collection/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
f"collection/{group}/reward/episode_reward_max": episode_reward.max().item(),
}
)
json_metrics[group + "_return"] = group_episode_rewards

if "info" in batch.get(("next", group)).keys():
to_log.update(
{
@@ -130,19 +118,12 @@ def log_collection(
}
)
to_log.update(task.log_info(batch))
mean_group_return = torch.stack(
[value for key, value in json_metrics.items()], dim=0
).mean(0)
if mean_group_return.numel() > 0:
to_log.update(
{
"collection/reward/episode_reward_min": mean_group_return.min().item(),
"collection/reward/episode_reward_mean": mean_group_return.mean().item(),
"collection/reward/episode_reward_max": mean_group_return.max().item(),
}
)
global_episode_rewards = self._log_global_episode_reward(
list(json_metrics.values()), to_log, prefix="collection"
)

self.log(to_log, step=step)
return mean_group_return.mean().item()
return global_episode_rewards.mean().item()

def log_training(self, group: str, training_td: TensorDictBase, step: int):
if not len(self.loggers):
@@ -171,7 +152,7 @@ def log_evaluation(
# Cut the rollouts at the first done
rollouts_group = []
for i, r in enumerate(rollouts):
next_done = self._get_done(group, r)
next_done = self._get_agents_done(group, r)
# Reduce it to batch size
next_done = next_done.sum(
tuple(range(r.batch_dims, next_done.ndim)),
@@ -265,33 +246,94 @@ def finish(self):
def _get_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "reward") not in td.keys(True, True):
reward = (
td.get(("next", "reward")).expand(td.get(group).shape).unsqueeze(-1)
)
else:
reward = td.get(("next", group, "reward"))
reward = td.get(("next", group, "reward"))
return reward.mean(-2) if remove_agent_dim else reward

def _get_done(self, group: str, td: TensorDictBase, remove_agent_dim: bool = False):
def _get_agents_done(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "done") not in td.keys(True, True):
done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
else:
done = td.get(("next", group, "done"))
return done.any(-2) if remove_agent_dim else done

def _get_global_done(
self,
group: str,
td: TensorDictBase,
):
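# The environment-level done at ("next", "done") is shared by all agents, so it
# carries no agent dimension (unlike the per-agent ("next", group, "done")).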
done = td.get(("next", "done"))
return done

def _get_episode_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "episode_reward") not in td.keys(True, True):
episode_reward = (
td.get(("next", "episode_reward"))
.expand(td.get(group).shape)
.unsqueeze(-1)
episode_reward = td.get(("next", group, "episode_reward"))
return episode_reward.mean(-2) if remove_agent_dim else episode_reward

def _log_individual_and_group_rewards(
self, group, batch, global_done, to_log, prefix: str
):
reward = self._get_reward(group, batch) # Has agent dim
episode_reward = self._get_episode_reward(group, batch) # Has agent dim
n_agents_in_group = episode_reward.shape[-2]

# The trajectories are considered until the global done
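# (the shared done is broadcast over the agent dimension to index each agent's episode return)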
episode_reward = episode_reward[
global_done.unsqueeze(-1).expand((*batch.get_item_shape(group), 1))
]

for i in range(n_agents_in_group):
to_log.update(
{
f"{prefix}/{group}/reward/agent_{i}/reward_min": reward[..., i, :]
.min()
.item(),
f"{prefix}/{group}/reward/agent_{i}/reward_mean": reward[..., i, :]
.mean()
.item(),
f"{prefix}/{group}/reward/agent_{i}/reward_max": reward[..., i, :]
.max()
.item(),
}
)
if episode_reward.numel() > 0:

to_log.update(
{
f"{prefix}/{group}/reward/episode_reward_min": episode_reward.min().item(),
f"{prefix}/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
f"{prefix}/{group}/reward/episode_reward_max": episode_reward.max().item(),
}
)
to_log.update(
{
f"{prefix}/{group}/reward/reward_min": reward.min().item(),
f"{prefix}/{group}/reward/reward_mean": reward.mean().item(),
f"{prefix}/{group}/reward/reward_max": reward.max().item(),
}
)
return episode_reward

def _log_global_episode_reward(self, episode_rewards, to_log, prefix: str):
# Each element in the list is the episode reward for one group at the global done; all elements have the same shape because the done is shared across groups
episode_rewards = torch.stack(episode_rewards, dim=0)
if episode_rewards.numel() > 0:
to_log.update(
{
f"{prefix}/reward/episode_reward_min": episode_rewards.min().item(),
f"{prefix}/reward/episode_reward_mean": episode_rewards.mean().item(),
f"{prefix}/reward/episode_reward_max": episode_rewards.max().item(),
}
)
else:
episode_reward = td.get(("next", group, "episode_reward"))
return episode_reward.mean(-2) if remove_agent_dim else episode_reward
warnings.warn(
"No episode terminated this iteration and thus the overall episode reward in NaN, "
"this is normal if your horizon is longer then one iteration. Learning is proceeding fine."
"The episodes will probably terminate in a future iteration."
)
return episode_rewards


class JsonWriter:

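As a rough illustration of the key-naming scheme the new helpers produce, here is a minimal standalone sketch. It is not BenchMARL code: the function name log_group_rewards, the dummy tensors, and their shapes are assumptions for demonstration only; only the key format and the done-masking logic mirror the diff above.

import torch


def log_group_rewards(group, reward, episode_reward, global_done, prefix="collection"):
    # reward and episode_reward: (*batch, n_agents, 1); global_done: (*batch, 1)
    to_log = {}
    n_agents = reward.shape[-2]
    # Per-agent instantaneous reward statistics
    for i in range(n_agents):
        to_log[f"{prefix}/{group}/reward/agent_{i}/reward_mean"] = (
            reward[..., i, :].mean().item()
        )
    # Group-level instantaneous reward statistics
    to_log[f"{prefix}/{group}/reward/reward_mean"] = reward.mean().item()
    # Episode returns are read off at the shared done, broadcast over the agent dimension
    mask = global_done.unsqueeze(-1).expand(episode_reward.shape)
    finished = episode_reward[mask]
    if finished.numel() > 0:
        to_log[f"{prefix}/{group}/reward/episode_reward_mean"] = finished.mean().item()
    return to_log, finished


# Dummy data: 6 collected frames, 2 agents in the group, 2 finished episodes
reward = torch.randn(6, 2, 1)
episode_reward = torch.randn(6, 2, 1)
global_done = torch.tensor([[False], [False], [True], [False], [False], [True]])
logs, finished = log_group_rewards("agents", reward, episode_reward, global_done)
print(sorted(logs))  # agent_0/..., agent_1/..., episode_reward_mean, reward_mean

In the commit itself, the per-group episode returns are additionally stacked across groups and averaged to produce the collection/reward/episode_reward_* entries, and a warning is emitted (instead of logging nothing silently) when no episode has terminated in the iteration.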