Skip to content

Commit

Permalink
try less tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 27, 2024
1 parent 573243f commit 64851a9
Showing 1 changed file with 87 additions and 95 deletions.
182 changes: 87 additions & 95 deletions test/test_magent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,12 @@

import pytest

from benchmarl.algorithms import (
algorithm_config_registry,
IppoConfig,
IsacConfig,
MappoConfig,
MasacConfig,
QmixConfig,
)
from benchmarl.algorithms import algorithm_config_registry
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import MAgentTask, Task
from benchmarl.experiment import Experiment

from utils import _has_magent2
from utils_experiment import ExperimentUtils


@pytest.mark.skipif(not _has_magent2, reason="magent2 not found")
Expand Down Expand Up @@ -49,89 +41,89 @@ def test_all_algos(
)
experiment.run()

@pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig, IsacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_gnn(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_gnn_sequence_config,
):
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_gnn_sequence_config,
critic_model_config=cnn_gnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_lstm(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_lstm_sequence_config,
):
algo_config = algo_config.get_from_yaml()
if algo_config.has_critic():
algo_config.share_param_critic = False
experiment_config.share_policy_params = False
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config,
model_config=cnn_lstm_sequence_config,
critic_model_config=cnn_lstm_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_reloading_trainer(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_sequence_config,
):
# To not run unsupported algo-task pairs
if not algo_config.supports_discrete_actions():
pytest.skip()
algo_config = algo_config.get_from_yaml()

ExperimentUtils.check_experiment_loading(
algo_config=algo_config,
model_config=cnn_sequence_config,
experiment_config=experiment_config,
task=task.get_from_yaml(),
)

@pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
@pytest.mark.parametrize("share_params", [True, False])
def test_share_policy_params(
self,
algo_config: AlgorithmConfig,
task: Task,
share_params,
experiment_config,
cnn_sequence_config,
):
experiment_config.share_policy_params = share_params
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()
# @pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig, IsacConfig])
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# def test_gnn(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# experiment_config,
# cnn_gnn_sequence_config,
# ):
# task = task.get_from_yaml()
# experiment = Experiment(
# algorithm_config=algo_config.get_from_yaml(),
# model_config=cnn_gnn_sequence_config,
# critic_model_config=cnn_gnn_sequence_config,
# seed=0,
# config=experiment_config,
# task=task,
# )
# experiment.run()
#
# @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# def test_lstm(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# experiment_config,
# cnn_lstm_sequence_config,
# ):
# algo_config = algo_config.get_from_yaml()
# if algo_config.has_critic():
# algo_config.share_param_critic = False
# experiment_config.share_policy_params = False
# task = task.get_from_yaml()
# experiment = Experiment(
# algorithm_config=algo_config,
# model_config=cnn_lstm_sequence_config,
# critic_model_config=cnn_lstm_sequence_config,
# seed=0,
# config=experiment_config,
# task=task,
# )
# experiment.run()
#
# @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# def test_reloading_trainer(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# experiment_config,
# cnn_sequence_config,
# ):
# # To not run unsupported algo-task pairs
# if not algo_config.supports_discrete_actions():
# pytest.skip()
# algo_config = algo_config.get_from_yaml()
#
# ExperimentUtils.check_experiment_loading(
# algo_config=algo_config,
# model_config=cnn_sequence_config,
# experiment_config=experiment_config,
# task=task.get_from_yaml(),
# )
#
# @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# @pytest.mark.parametrize("share_params", [True, False])
# def test_share_policy_params(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# share_params,
# experiment_config,
# cnn_sequence_config,
# ):
# experiment_config.share_policy_params = share_params
# task = task.get_from_yaml()
# experiment = Experiment(
# algorithm_config=algo_config.get_from_yaml(),
# model_config=cnn_sequence_config,
# seed=0,
# config=experiment_config,
# task=task,
# )
# experiment.run()

0 comments on commit 64851a9

Please sign in to comment.