From 39884c8af3e18282adc2205f34bd4f90d7044c70 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Nov 2024 20:26:43 +0000 Subject: [PATCH] bugs --- benchmarl/experiment/experiment.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index ce2d40eb..56e8634a 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional import torch +from models import GnnConfig, SequenceModelConfig from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from torchrl.collectors import SyncDataCollector @@ -361,7 +362,31 @@ def _setup(self): self._on_setup() def _perfrom_checks(self): - pass + model_configs = [self.model_config, self.critic_model_config] + for config in model_configs: + if isinstance(config, SequenceModelConfig): + for layer_config in self.critic_model_config.model_configs[1:]: + if ( + isinstance(layer_config, GnnConfig) + and layer_config.topology == "from_pos" + ): + raise ValueError( + "GNNs with topology 'from_pos' are currently only usable in first" + " layer of sequence models" + ) + + if self.algorithm_name in ("mappo", "ippo"): + critic_model_config = self.critic_model_config + if isinstance(self.critic_model_config, SequenceModelConfig): + critic_model_config = self.critic_model_config.model_configs[0] + if ( + isinstance(critic_model_config, GnnConfig) + and critic_model_config.topology == "from_pos" + ): + raise ValueError( + "GNNs in PPO critics with topology 'from_pos' are currently not available, " + "see https://github.com/pytorch/rl/issues/2537" + ) def _set_action_type(self): if (