From 6edcdc170e52f4c746b68a2c9efb9c134c8c0ec0 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Sat, 27 Jul 2024 20:58:25 +0200 Subject: [PATCH] [BugFix] Fix collect with grad (#114) * amend * amend * amend --- benchmarl/conf/experiment/base_experiment.yaml | 2 +- benchmarl/experiment/experiment.py | 7 ++++++- test/test_vmas.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index bb6f6d94..89566850 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -82,7 +82,7 @@ render: True evaluation_interval: 120_000 # Number of episodes that evaluation is run on evaluation_episodes: 10 -# If True, when stochastic policies are evaluated, their mode is taken, otherwise, if False, they are sampled +# If True, when stochastic policies are evaluated, their deterministic value is taken, otherwise, if False, they are sampled evaluation_deterministic_actions: True # List of loggers to use, options are: wandb, csv, tensorboard, mflow diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 90ef568b..b09a7abc 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -582,7 +582,12 @@ def _collection_loop(self): auto_reset=False, tensordict=reset_batch, ) - reset_batch = step_mdp(batch[..., -1]) + reset_batch = step_mdp( + batch[..., -1], + reward_keys=self.rollout_env.reward_keys, + action_keys=self.rollout_env.action_keys, + done_keys=self.rollout_env.done_keys, + ) # Logging collection collection_time = time.time() - iteration_start diff --git a/test/test_vmas.py b/test/test_vmas.py index 3dfc4f81..9a6eb6cc 100644 --- a/test/test_vmas.py +++ b/test/test_vmas.py @@ -74,6 +74,24 @@ def test_all_tasks( ) experiment.run() + def test_collect_with_grad( + self, + experiment_config, + mlp_sequence_config, + algo_config: AlgorithmConfig = IppoConfig, + task: Task = VmasTask.BALANCE, + ): + task = task.get_from_yaml() + experiment_config.collect_with_grad = True + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + @pytest.mark.parametrize( "algo_config", [IppoConfig, QmixConfig, IsacConfig, IddpgConfig] )