From 39884c8af3e18282adc2205f34bd4f90d7044c70 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Nov 2024 20:26:43 +0000 Subject: [PATCH 1/8] bugs --- benchmarl/experiment/experiment.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index ce2d40eb..56e8634a 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional import torch +from models import GnnConfig, SequenceModelConfig from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from torchrl.collectors import SyncDataCollector @@ -361,7 +362,31 @@ def _setup(self): self._on_setup() def _perfrom_checks(self): - pass + model_configs = [self.model_config, self.critic_model_config] + for config in model_configs: + if isinstance(config, SequenceModelConfig): + for layer_config in self.critic_model_config.model_configs[1:]: + if ( + isinstance(layer_config, GnnConfig) + and layer_config.topology == "from_pos" + ): + raise ValueError( + "GNNs with topology 'from_pos' are currently only usable in first" + " layer of sequence models" + ) + + if self.algorithm_name in ("mappo", "ippo"): + critic_model_config = self.critic_model_config + if isinstance(self.critic_model_config, SequenceModelConfig): + critic_model_config = self.critic_model_config.model_configs[0] + if ( + isinstance(critic_model_config, GnnConfig) + and critic_model_config.topology == "from_pos" + ): + raise ValueError( + "GNNs in PPO critics with topology 'from_pos' are currently not available, " + "see https://github.com/pytorch/rl/issues/2537" + ) def _set_action_type(self): if ( From 4b9992e8500b33d0aa0a9a3723fcb10b867d097c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Nov 2024 20:30:34 +0000 Subject: [PATCH 2/8] bugs --- benchmarl/experiment/experiment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 56e8634a..1f453c1f 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -366,9 +366,9 @@ def _perfrom_checks(self): for config in model_configs: if isinstance(config, SequenceModelConfig): for layer_config in self.critic_model_config.model_configs[1:]: - if ( - isinstance(layer_config, GnnConfig) - and layer_config.topology == "from_pos" + if isinstance(layer_config, GnnConfig) and ( + layer_config.position_key is not None + or layer_config.velocity_key is not None ): raise ValueError( "GNNs with topology 'from_pos' are currently only usable in first" From 9ff1fb7117982298ca2c9b21d67443c6392789e1 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Nov 2024 20:31:00 +0000 Subject: [PATCH 3/8] bugs --- benchmarl/experiment/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 1f453c1f..22264b3f 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -371,7 +371,7 @@ def _perfrom_checks(self): or layer_config.velocity_key is not None ): raise ValueError( - "GNNs with topology 'from_pos' are currently only usable in first" + "GNNs reading position or velocity keys are currently only usable in first" " layer of sequence models" ) From 1811f6ed0467475051c4205f4230312c37e0a8f5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Nov 2024 20:31:55 +0000 Subject: [PATCH 4/8] bugs --- benchmarl/experiment/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 22264b3f..c1d2b593 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -365,7 +365,7 @@ def _perfrom_checks(self): model_configs = [self.model_config, self.critic_model_config] for config in model_configs: if isinstance(config, SequenceModelConfig): - for layer_config in self.critic_model_config.model_configs[1:]: + for layer_config in config[1:]: if isinstance(layer_config, GnnConfig) and ( layer_config.position_key is not None or layer_config.velocity_key is not None From 4a8af84f589f2257ae00b78f9c0027fba6307d84 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Nov 2024 20:33:06 +0000 Subject: [PATCH 5/8] bugs --- benchmarl/experiment/experiment.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index c1d2b593..b75de28c 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -362,10 +362,9 @@ def _setup(self): self._on_setup() def _perfrom_checks(self): - model_configs = [self.model_config, self.critic_model_config] - for config in model_configs: + for config in (self.model_config, self.critic_model_config): if isinstance(config, SequenceModelConfig): - for layer_config in config[1:]: + for layer_config in config.model_configs[1:]: if isinstance(layer_config, GnnConfig) and ( layer_config.position_key is not None or layer_config.velocity_key is not None @@ -377,7 +376,7 @@ def _perfrom_checks(self): if self.algorithm_name in ("mappo", "ippo"): critic_model_config = self.critic_model_config - if isinstance(self.critic_model_config, SequenceModelConfig): + if isinstance(critic_model_config, SequenceModelConfig): critic_model_config = self.critic_model_config.model_configs[0] if ( isinstance(critic_model_config, GnnConfig) From bceb7e6b98a656dc68c090df7cf03c020d0a0a6d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Nov 2024 09:34:21 +0000 Subject: [PATCH 6/8] bugs --- benchmarl/experiment/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index b75de28c..39b7e8ab 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -17,7 +17,6 @@ from typing import Any, Dict, List, Optional import torch -from models import GnnConfig, SequenceModelConfig from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from torchrl.collectors import SyncDataCollector @@ -31,6 +30,7 @@ from benchmarl.environments import Task from benchmarl.experiment.callback import Callback, CallbackNotifier from benchmarl.experiment.logger import Logger +from benchmarl.models import GnnConfig, SequenceModelConfig from benchmarl.models.common import ModelConfig from benchmarl.utils import _read_yaml_config, seed_everything From ecf6c643411347e96317cd95af48d46a335dedfa Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Nov 2024 09:46:18 +0000 Subject: [PATCH 7/8] bugs --- benchmarl/experiment/experiment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 39b7e8ab..645fdb3b 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional import torch +from algorithms import IppoConfig, MappoConfig from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from torchrl.collectors import SyncDataCollector @@ -374,7 +375,7 @@ def _perfrom_checks(self): " layer of sequence models" ) - if self.algorithm_name in ("mappo", "ippo"): + if self.algorithm_config in (MappoConfig, IppoConfig): critic_model_config = self.critic_model_config if isinstance(critic_model_config, SequenceModelConfig): critic_model_config = self.critic_model_config.model_configs[0] From 2d3a0e2a916f077c45ee774a8f9f06f2e2e8d139 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Nov 2024 09:46:41 +0000 Subject: [PATCH 8/8] bugs --- benchmarl/experiment/experiment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 645fdb3b..ff0c47cf 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -17,7 +17,6 @@ from typing import Any, Dict, List, Optional import torch -from algorithms import IppoConfig, MappoConfig from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from torchrl.collectors import SyncDataCollector @@ -27,6 +26,8 @@ from torchrl.record.loggers import generate_exp_name from tqdm import tqdm +from benchmarl.algorithms import IppoConfig, MappoConfig + from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task from benchmarl.experiment.callback import Callback, CallbackNotifier