diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index ce2d40eb..ff0c47cf 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -26,10 +26,13 @@ from torchrl.record.loggers import generate_exp_name from tqdm import tqdm +from benchmarl.algorithms import IppoConfig, MappoConfig + from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task from benchmarl.experiment.callback import Callback, CallbackNotifier from benchmarl.experiment.logger import Logger +from benchmarl.models import GnnConfig, SequenceModelConfig from benchmarl.models.common import ModelConfig from benchmarl.utils import _read_yaml_config, seed_everything @@ -361,7 +364,30 @@ def _setup(self): self._on_setup() def _perfrom_checks(self): - pass + for config in (self.model_config, self.critic_model_config): + if isinstance(config, SequenceModelConfig): + for layer_config in config.model_configs[1:]: + if isinstance(layer_config, GnnConfig) and ( + layer_config.position_key is not None + or layer_config.velocity_key is not None + ): + raise ValueError( + "GNNs reading position or velocity keys are currently only usable in first" + " layer of sequence models" + ) + + if self.algorithm_config in (MappoConfig, IppoConfig): + critic_model_config = self.critic_model_config + if isinstance(critic_model_config, SequenceModelConfig): + critic_model_config = self.critic_model_config.model_configs[0] + if ( + isinstance(critic_model_config, GnnConfig) + and critic_model_config.topology == "from_pos" + ): + raise ValueError( + "GNNs in PPO critics with topology 'from_pos' are currently not available, " + "see https://github.com/pytorch/rl/issues/2537" + ) def _set_action_type(self): if (