diff --git a/test/test_vmas.py b/test/test_vmas.py index 9069ff3a..07141a4c 100644 --- a/test/test_vmas.py +++ b/test/test_vmas.py @@ -2,7 +2,14 @@ import pytest -from benchmarl.algorithms import algorithm_config_registry, MappoConfig, QmixConfig +from benchmarl.algorithms import ( + algorithm_config_registry, + IppoConfig, + MaddpgConfig, + MappoConfig, + MasacConfig, + QmixConfig, +) from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task, VmasTask from benchmarl.experiment import Experiment @@ -50,3 +57,27 @@ def test_reloading_trainer( experiment_config=experiment_config, task=task.get_from_yaml(), ) + + @pytest.mark.parametrize( + "algo_config", [QmixConfig, IppoConfig, MaddpgConfig, MasacConfig] + ) + @pytest.mark.parametrize("task", [VmasTask.NAVIGATION]) + @pytest.mark.parametrize("share_params", [True, False]) + def test_share_policy_params( + self, + algo_config: AlgorithmConfig, + task: Task, + share_params, + experiment_config, + mlp_sequence_config, + ): + experiment_config.share_policy_params = share_params + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run()