Skip to content

Commit

Permalink
ci
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 24, 2024
1 parent 04622ab commit 4beb8f7
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .github/unittest/install_magent2.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

pip install "pettingzoo[all]==1.24.3"
pip install git+https://github.com/Farama-Foundation/MAgent2

sudo apt-get update
sudo apt-get install python3-opengl xvfb
43 changes: 43 additions & 0 deletions .github/workflows/magent_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see:
# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions


name: magent_tests

on:
push:
branches: [ $default-branch , "main" ]
pull_request:
branches: [ $default-branch , "main" ]

permissions:
contents: read

jobs:
tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
bash .github/unittest/install_dependencies_nightly.sh
- name: Install pettingzoo
run: |
bash .github/unittest/install_magent2.sh
- name: Test with pytest
run: |
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_magent.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
54 changes: 54 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ def mlp_gnn_sequence_config() -> ModelConfig:
)


@pytest.fixture
def cnn_gnn_sequence_config() -> ModelConfig:
return SequenceModelConfig(
model_configs=[
CnnConfig(
cnn_num_cells=[4, 3],
cnn_kernel_sizes=[3, 2],
cnn_strides=1,
cnn_paddings=0,
cnn_activation_class=nn.Tanh,
mlp_num_cells=[4],
mlp_activation_class=nn.Tanh,
mlp_layer_class=nn.Linear,
),
GnnConfig(
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GATv2Conv,
),
MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear),
],
intermediate_sizes=[5, 3],
)


@pytest.fixture
def gru_mlp_sequence_config() -> ModelConfig:
return SequenceModelConfig(
Expand Down Expand Up @@ -128,3 +153,32 @@ def lstm_mlp_sequence_config() -> ModelConfig:
],
intermediate_sizes=[5],
)


@pytest.fixture
def cnn_lstm_sequence_config() -> ModelConfig:
return SequenceModelConfig(
model_configs=[
CnnConfig(
cnn_num_cells=[4, 3],
cnn_kernel_sizes=[3, 2],
cnn_strides=1,
cnn_paddings=0,
cnn_activation_class=nn.Tanh,
mlp_num_cells=[4],
mlp_activation_class=nn.Tanh,
mlp_layer_class=nn.Linear,
),
LstmConfig(
hidden_size=13,
mlp_num_cells=[],
mlp_activation_class=nn.Tanh,
mlp_layer_class=nn.Linear,
n_layers=1,
bias=True,
dropout=0,
compile=False,
),
],
intermediate_sizes=[5],
)
135 changes: 135 additions & 0 deletions test/test_magent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import pytest
from benchmarl.algorithms import (
algorithm_config_registry,
IppoConfig,
IsacConfig,
MasacConfig,
QmixConfig,
)
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import MAgentTask, Task
from benchmarl.experiment import Experiment

from utils import _has_magent2
from utils_experiment import ExperimentUtils


@pytest.mark.skipif(not _has_magent2, reason="magent2 not found")
class TestMagent:
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_all_algos(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_sequence_config,
):

# To not run unsupported algo-task pairs
if not algo_config.supports_discrete_actions():
pytest.skip()

task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, IsacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_gnn(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_gnn_sequence_config,
):
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_gnn_sequence_config,
critic_model_config=cnn_gnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_lstm(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_lstm_sequence_config,
):
algo_config = algo_config.get_from_yaml()
if algo_config.has_critic():
algo_config.share_param_critic = False
experiment_config.share_policy_params = False
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config,
model_config=cnn_lstm_sequence_config,
critic_model_config=cnn_lstm_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_reloading_trainer(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_sequence_config,
):
# To not run unsupported algo-task pairs
if not algo_config.supports_discrete_actions():
pytest.skip()
algo_config = algo_config.get_from_yaml()

ExperimentUtils.check_experiment_loading(
algo_config=algo_config,
model_config=cnn_sequence_config,
experiment_config=experiment_config,
task=task.get_from_yaml(),
)

@pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
@pytest.mark.parametrize("share_params", [True, False])
def test_share_policy_params(
self,
algo_config: AlgorithmConfig,
task: Task,
share_params,
experiment_config,
cnn_sequence_config,
):
experiment_config.share_policy_params = share_params
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()
1 change: 1 addition & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
_has_smacv2 = importlib.util.find_spec("smacv2") is not None
_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None
_has_meltingpot = importlib.util.find_spec("meltingpot") is not None
_has_magent2 = importlib.util.find_spec("magent2") is not None

0 comments on commit 4beb8f7

Please sign in to comment.