Commit 20ad0fb

amend
matteobettini committed Oct 4, 2024
1 parent d62b21e · commit 20ad0fb
Showing 1 changed file with 6 additions and 9 deletions.
benchmarl/experiment/logger.py (15 changes: 6 additions & 9 deletions)
@@ -182,17 +182,15 @@ def log_evaluation(
         to_log = {}
         json_metrics = {}
         for group in self.group_map.keys():
-            # returns has shape (n_episodes, n_agents_in_group)
+            # returns has shape (n_episodes)
             returns = torch.stack(
-                [self._get_reward(group, td).sum(0) for td in rollouts],
+                [self._get_reward(group, td).sum(0).mean() for td in rollouts],
                 dim=0,
             )
             self._log_min_mean_max(
                 to_log, f"eval/{group}/reward/episode_reward", returns
             )
-            json_metrics[group + "_return"] = returns.mean(
-                dim=tuple(range(1, returns.ndim))
-            )  # result has shape (n_episodes) as we take the mean over agents in the group
+            json_metrics[group + "_return"] = returns
 
             mean_group_return = self._log_global_episode_reward(
                 list(json_metrics.values()), to_log, prefix="eval"
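The net effect of this hunk is that returns now holds one scalar per evaluation episode (the episode return averaged over the agents in the group) rather than one value per agent, so it can be stored in json_metrics directly. A minimal sketch of the two aggregations, assuming each rollout's reward tensor has shape (n_steps, n_agents_in_group); the toy tensors and names below are illustrative, not BenchMARL's API:

import torch

# Hypothetical per-step rewards for two episodes of a two-agent group.
rollout_rewards = [
    torch.tensor([[1.0, 0.0], [2.0, 4.0]]),  # episode 1: 2 steps x 2 agents
    torch.tensor([[0.5, 0.5]]),              # episode 2: 1 step x 2 agents
]

# Old aggregation: per-agent episode returns, shape (n_episodes, n_agents_in_group).
old_returns = torch.stack([r.sum(0) for r in rollout_rewards], dim=0)
print(old_returns)  # tensor([[3.0000, 4.0000], [0.5000, 0.5000]])

# New aggregation: mean over agents taken inside the stack, shape (n_episodes,).
new_returns = torch.stack([r.sum(0).mean() for r in rollout_rewards], dim=0)
print(new_returns)  # tensor([3.5000, 0.5000])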
@@ -318,16 +316,15 @@ def _log_individual_and_group_rewards(
                         episode_reward[..., i, :][agent_global_done],
                     )
 
-        # 2. Here we log rewards from group data
+        # 2. Here we log rewards from group data taking the mean over agents
+        group_episode_reward = episode_reward.mean(-2)[global_done]
         if any_episode_ended:
-            group_episode_reward = episode_reward[unsqueeze_global_done]
             self._log_min_mean_max(
                 to_log, f"{prefix}/{group}/reward/episode_reward", group_episode_reward
             )
         self._log_min_mean_max(to_log, f"{prefix}/reward/reward", reward)
 
-        # 3. We take the mean over agents in the group so that we will later log from a global perspecitve
-        return episode_reward.mean(-2)[global_done]
+        return group_episode_reward
 
     def _log_global_episode_reward(
         self, episode_rewards: List[Tensor], to_log: Dict[str, Tensor], prefix: str
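Hoisting the assignment out of the if block means group_episode_reward is defined, and returned, even when no episode has ended, and it now matches what gets logged: the episode reward averaged over the agent dimension, masked by the global done flags. A toy illustration of that expression, with assumed shapes (the tensors and sizes are made up for the example):

import torch

# Hypothetical batch of 4 frames, a 2-agent group, scalar rewards.
episode_reward = torch.tensor(
    [[[1.0], [3.0]],
     [[0.0], [2.0]],
     [[5.0], [5.0]],
     [[2.0], [4.0]]]
)                                                               # (4, n_agents=2, 1)
global_done = torch.tensor([[True], [False], [True], [False]])  # (4, 1)

# mean(-2) averages over the agent dimension -> (4, 1); boolean indexing
# then keeps only the frames where an episode actually ended.
group_episode_reward = episode_reward.mean(-2)[global_done]
print(group_episode_reward)  # tensor([2., 5.])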
