Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Sep 21, 2023
1 parent bfe7297 commit 12a8e81
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

import pytest

from benchmarl.algorithms import algorithm_config_registry, MappoConfig, QmixConfig
from benchmarl.algorithms import (
algorithm_config_registry,
IppoConfig,
MaddpgConfig,
MappoConfig,
MasacConfig,
QmixConfig,
)
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, VmasTask
from benchmarl.experiment import Experiment
Expand Down Expand Up @@ -50,3 +57,27 @@ def test_reloading_trainer(
experiment_config=experiment_config,
task=task.get_from_yaml(),
)

@pytest.mark.parametrize(
"algo_config", [QmixConfig, IppoConfig, MaddpgConfig, MasacConfig]
)
@pytest.mark.parametrize("task", [VmasTask.NAVIGATION])
@pytest.mark.parametrize("share_params", [True, False])
def test_share_policy_params(
self,
algo_config: AlgorithmConfig,
task: Task,
share_params,
experiment_config,
mlp_sequence_config,
):
experiment_config.share_policy_params = share_params
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=mlp_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

0 comments on commit 12a8e81

Please sign in to comment.