Skip to content

Commit

Permalink
bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 5, 2024
1 parent 747e0c4 commit 39884c8
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, List, Optional

import torch
from models import GnnConfig, SequenceModelConfig
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector
Expand Down Expand Up @@ -361,7 +362,31 @@ def _setup(self):
self._on_setup()

def _perfrom_checks(self):
pass
model_configs = [self.model_config, self.critic_model_config]
for config in model_configs:
if isinstance(config, SequenceModelConfig):
for layer_config in self.critic_model_config.model_configs[1:]:
if (
isinstance(layer_config, GnnConfig)
and layer_config.topology == "from_pos"
):
raise ValueError(
"GNNs with topology 'from_pos' are currently only usable in first"
" layer of sequence models"
)

if self.algorithm_name in ("mappo", "ippo"):
critic_model_config = self.critic_model_config
if isinstance(self.critic_model_config, SequenceModelConfig):
critic_model_config = self.critic_model_config.model_configs[0]
if (
isinstance(critic_model_config, GnnConfig)
and critic_model_config.topology == "from_pos"
):
raise ValueError(
"GNNs in PPO critics with topology 'from_pos' are currently not available, "
"see https://github.com/pytorch/rl/issues/2537"
)

def _set_action_type(self):
if (
Expand Down

0 comments on commit 39884c8

Please sign in to comment.