-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
04622ab
commit 4beb8f7
Showing
5 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
pip install "pettingzoo[all]==1.24.3" | ||
pip install git+https://github.com/Farama-Foundation/MAgent2 | ||
|
||
sudo apt-get update | ||
sudo apt-get install python3-opengl xvfb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: | ||
# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
|
||
name: magent_tests | ||
|
||
on: | ||
push: | ||
branches: [ $default-branch , "main" ] | ||
pull_request: | ||
branches: [ $default-branch , "main" ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
tests: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.11"] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
bash .github/unittest/install_dependencies_nightly.sh | ||
- name: Install pettingzoo | ||
run: | | ||
bash .github/unittest/install_magent2.sh | ||
- name: Test with pytest | ||
run: | | ||
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_magent.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html | ||
- name: Upload coverage to Codecov | ||
uses: codecov/codecov-action@v3 | ||
with: | ||
fail_ci_if_error: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
|
||
import pytest | ||
from benchmarl.algorithms import ( | ||
algorithm_config_registry, | ||
IppoConfig, | ||
IsacConfig, | ||
MasacConfig, | ||
QmixConfig, | ||
) | ||
from benchmarl.algorithms.common import AlgorithmConfig | ||
from benchmarl.environments import MAgentTask, Task | ||
from benchmarl.experiment import Experiment | ||
|
||
from utils import _has_magent2 | ||
from utils_experiment import ExperimentUtils | ||
|
||
|
||
@pytest.mark.skipif(not _has_magent2, reason="magent2 not found") | ||
class TestMagent: | ||
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) | ||
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) | ||
def test_all_algos( | ||
self, | ||
algo_config: AlgorithmConfig, | ||
task: Task, | ||
experiment_config, | ||
cnn_sequence_config, | ||
): | ||
|
||
# To not run unsupported algo-task pairs | ||
if not algo_config.supports_discrete_actions(): | ||
pytest.skip() | ||
|
||
task = task.get_from_yaml() | ||
experiment = Experiment( | ||
algorithm_config=algo_config.get_from_yaml(), | ||
model_config=cnn_sequence_config, | ||
seed=0, | ||
config=experiment_config, | ||
task=task, | ||
) | ||
experiment.run() | ||
|
||
@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, IsacConfig]) | ||
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) | ||
def test_gnn( | ||
self, | ||
algo_config: AlgorithmConfig, | ||
task: Task, | ||
experiment_config, | ||
cnn_gnn_sequence_config, | ||
): | ||
task = task.get_from_yaml() | ||
experiment = Experiment( | ||
algorithm_config=algo_config.get_from_yaml(), | ||
model_config=cnn_gnn_sequence_config, | ||
critic_model_config=cnn_gnn_sequence_config, | ||
seed=0, | ||
config=experiment_config, | ||
task=task, | ||
) | ||
experiment.run() | ||
|
||
@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig]) | ||
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) | ||
def test_lstm( | ||
self, | ||
algo_config: AlgorithmConfig, | ||
task: Task, | ||
experiment_config, | ||
cnn_lstm_sequence_config, | ||
): | ||
algo_config = algo_config.get_from_yaml() | ||
if algo_config.has_critic(): | ||
algo_config.share_param_critic = False | ||
experiment_config.share_policy_params = False | ||
task = task.get_from_yaml() | ||
experiment = Experiment( | ||
algorithm_config=algo_config, | ||
model_config=cnn_lstm_sequence_config, | ||
critic_model_config=cnn_lstm_sequence_config, | ||
seed=0, | ||
config=experiment_config, | ||
task=task, | ||
) | ||
experiment.run() | ||
|
||
@pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) | ||
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) | ||
def test_reloading_trainer( | ||
self, | ||
algo_config: AlgorithmConfig, | ||
task: Task, | ||
experiment_config, | ||
cnn_sequence_config, | ||
): | ||
# To not run unsupported algo-task pairs | ||
if not algo_config.supports_discrete_actions(): | ||
pytest.skip() | ||
algo_config = algo_config.get_from_yaml() | ||
|
||
ExperimentUtils.check_experiment_loading( | ||
algo_config=algo_config, | ||
model_config=cnn_sequence_config, | ||
experiment_config=experiment_config, | ||
task=task.get_from_yaml(), | ||
) | ||
|
||
@pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) | ||
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) | ||
@pytest.mark.parametrize("share_params", [True, False]) | ||
def test_share_policy_params( | ||
self, | ||
algo_config: AlgorithmConfig, | ||
task: Task, | ||
share_params, | ||
experiment_config, | ||
cnn_sequence_config, | ||
): | ||
experiment_config.share_policy_params = share_params | ||
task = task.get_from_yaml() | ||
experiment = Experiment( | ||
algorithm_config=algo_config.get_from_yaml(), | ||
model_config=cnn_sequence_config, | ||
seed=0, | ||
config=experiment_config, | ||
task=task, | ||
) | ||
experiment.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters