diff --git a/.github/unittest/install_dependencies.sh b/.github/unittest/install_dependencies.sh new file mode 100644 index 00000000..e44c5bd9 --- /dev/null +++ b/.github/unittest/install_dependencies.sh @@ -0,0 +1,16 @@ + + +python -m pip install --upgrade pip +python -m pip install flake8 pytest pytest-cov hydra-core + +if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + +python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall + +cd .. +python -m pip install git+https://github.com/pytorch-labs/tensordict.git +git clone https://github.com/pytorch/rl.git +cd rl +python setup.py develop +cd ../BenchMARL +pip install -e . diff --git a/.github/unittest/install_smacv2.sh b/.github/unittest/install_smacv2.sh new file mode 100644 index 00000000..fd57cf0a --- /dev/null +++ b/.github/unittest/install_smacv2.sh @@ -0,0 +1,22 @@ + + +root_dir="$(git rev-parse --show-toplevel)" +cd "${root_dir}" + +starcraft_path="${root_dir}/StarCraftII" +map_dir="${starcraft_path}/Maps" +printf "* Installing StarCraft 2 and SMACv2 maps into ${starcraft_path}\n" +cd "${root_dir}" +wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip +# The archive contains StarCraftII folder. Password comes from the documentation. +unzip -qo -P iagreetotheeula SC2.4.10.zip +mkdir -p "${map_dir}" +# Install Maps +wget https://github.com/oxwhirl/smacv2/releases/download/maps/SMAC_Maps.zip +unzip SMAC_Maps.zip +mkdir "${map_dir}/SMAC_Maps" +mv *.SC2Map "${map_dir}/SMAC_Maps" +printf "StarCraft II and SMAC are installed." + +pip install numpy==1.23.0 +pip install git+https://github.com/oxwhirl/smacv2.git diff --git a/.github/unittest/install_vmas.sh b/.github/unittest/install_vmas.sh new file mode 100644 index 00000000..86c7ebf4 --- /dev/null +++ b/.github/unittest/install_vmas.sh @@ -0,0 +1,4 @@ + +python -m pip install vmas +sudo apt-get update +sudo apt-get install python3-opengl xvfb diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..ff470a37 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,45 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + + +name: lint + +on: + push: + branches: [ $default-branch , "main" ] + pull_request: + branches: [ $default-branch , "main" ] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Lint + run: | + python -m pip install --upgrade pip + pip install pre-commit + + set +e + pre-commit run --all-files + + if [ $? -ne 0 ]; then + git --no-pager diff + exit 1 + fi + + diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml deleted file mode 100644 index a9f27c8b..00000000 --- a/.github/workflows/python-app.yml +++ /dev/null @@ -1,58 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: -# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - - -name: pytest - -on: - push: - branches: [ $default-branch , "main" , "dev" ] - pull_request: - branches: [ $default-branch , "main" ] - -permissions: - contents: read - -jobs: - build: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.8", "3.9", "3.10"] - - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - python -m pip install vmas - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall - - cd .. - python -m pip install git+https://github.com/pytorch-labs/tensordict.git - git clone https://github.com/pytorch/rl.git - cd rl - python setup.py develop - cd ../BenchMARL - - pip install -e . - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 benchmarl/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 benchmarl/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --ignore=E203,W503 - - name: Test with pytest - run: | - pip install pytest - pip install pytest-cov - pip install tqdm - pytest test/ --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html diff --git a/.github/workflows/smacv2_tests.yml b/.github/workflows/smacv2_tests.yml new file mode 100644 index 00000000..90fd595f --- /dev/null +++ b/.github/workflows/smacv2_tests.yml @@ -0,0 +1,45 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + + +name: smacv2_tests + +on: + push: + branches: [ $default-branch , "main" ] + pull_request: + branches: [ $default-branch , "main" ] + +permissions: + contents: read + +jobs: + build: + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'smacv2') }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + bash .github/unittest/install_dependencies.sh + - name: Install smacv2 + run: | + bash .github/unittest/install_smacv2.sh + + - name: Test with pytest + run: | + root_dir="$(git rev-parse --show-toplevel)" + export SC2PATH="${root_dir}/StarCraftII" + echo 'SC2PATH is set to ' "$SC2PATH" + + pytest test/test_smacv2.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml new file mode 100644 index 00000000..4f981034 --- /dev/null +++ b/.github/workflows/unit_tests.yml @@ -0,0 +1,37 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + + +name: unit_tests + +on: + push: + branches: [ $default-branch , "main" ] + pull_request: + branches: [ $default-branch , "main" ] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + bash .github/unittest/install_dependencies.sh + + - name: Test with pytest + run: | + pytest test/ --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html diff --git a/.github/workflows/vmas_tests.yml b/.github/workflows/vmas_tests.yml new file mode 100644 index 00000000..8925918a --- /dev/null +++ b/.github/workflows/vmas_tests.yml @@ -0,0 +1,39 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + + +name: vmas_tests + +on: + push: + branches: [ $default-branch , "main" ] + pull_request: + branches: [ $default-branch , "main" ] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + bash .github/unittest/install_dependencies.sh + - name: Install vmas + run: | + bash .github/unittest/install_vmas.sh + - name: Test with pytest + run: | + xvfb-run -s "-screen 0 1024x768x24" pytest test/test_vmas.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html diff --git a/.gitignore b/.gitignore index 9b8f40aa..f062df2d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,12 @@ # Hydra outputs/ +examples/outputs/ +benchmarl/outputs/ multirun/ +examples/multirun/ +benchmarl/multirun/ + # Byte-compiled / optimized / DLL files __pycache__/ @@ -43,6 +48,7 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +test/tmp/ htmlcov/ .tox/ .nox/ diff --git a/benchmarl/__init__.py b/benchmarl/__init__.py index 5b0d31ee..9cb5695f 100644 --- a/benchmarl/__init__.py +++ b/benchmarl/__init__.py @@ -1,20 +1,25 @@ -def load_hydra_schemas(): - from hydra.core.config_store import ConfigStore +import importlib - from benchmarl.algorithms import algorithm_config_registry - from benchmarl.environments import _task_class_registry - from benchmarl.experiment import ExperimentConfig +_has_hydra = importlib.util.find_spec("hydra") is not None - # Create instance to load hydra schemas - cs = ConfigStore.instance() - # Load experiment schema - cs.store(name="experiment_config", group="experiment", node=ExperimentConfig) - # Load algos schemas - for algo_name, algo_schema in algorithm_config_registry.items(): - cs.store(name=f"{algo_name}_config", group="algorithm", node=algo_schema) - # Load rask schemas - for task_schema_name, task_schema in _task_class_registry.items(): - cs.store(name=task_schema_name, group="task", node=task_schema) +if _has_hydra: + def load_hydra_schemas(): + from hydra.core.config_store import ConfigStore -load_hydra_schemas() + from benchmarl.algorithms import algorithm_config_registry + from benchmarl.environments import _task_class_registry + from benchmarl.experiment import ExperimentConfig + + # Create instance to load hydra schemas + cs = ConfigStore.instance() + # Load experiment schema + cs.store(name="experiment_config", group="experiment", node=ExperimentConfig) + # Load algos schemas + for algo_name, algo_schema in algorithm_config_registry.items(): + cs.store(name=f"{algo_name}_config", group="algorithm", node=algo_schema) + # Load task schemas + for task_schema_name, task_schema in _task_class_registry.items(): + cs.store(name=task_schema_name, group="task", node=task_schema) + + load_hydra_schemas() diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py index f4edae55..631713ce 100644 --- a/benchmarl/algorithms/iddpg.py +++ b/benchmarl/algorithms/iddpg.py @@ -29,16 +29,10 @@ class Iddpg(Algorithm): def __init__( - self, - share_param_actor: bool = True, - share_param_critic: bool = True, - loss_function: str = "l2", - delay_value: bool = True, - **kwargs + self, share_param_critic: bool, loss_function: str, delay_value: bool, **kwargs ): super().__init__(**kwargs) - self.share_param_actor = share_param_actor self.share_param_critic = share_param_critic self.delay_value = delay_value self.loss_function = loss_function @@ -142,7 +136,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_param_actor, + share_params=self.experiment_config.share_policy_params, device=self.device, ) @@ -258,7 +252,6 @@ def get_value_module(self, group: str) -> TensorDictModule: @dataclass class IddpgConfig(AlgorithmConfig): - share_param_actor: bool = MISSING share_param_critic: bool = MISSING loss_function: str = MISSING delay_value: bool = MISSING diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py index 0a838648..dd06e422 100644 --- a/benchmarl/algorithms/ippo.py +++ b/benchmarl/algorithms/ippo.py @@ -27,7 +27,6 @@ class Ippo(Algorithm): def __init__( self, - share_param_actor: bool, share_param_critic: bool, clip_epsilon: float, entropy_coef: bool, @@ -38,7 +37,6 @@ def __init__( ): super().__init__(**kwargs) - self.share_param_actor = share_param_actor self.share_param_critic = share_param_critic self.clip_epsilon = clip_epsilon self.entropy_coef = entropy_coef @@ -147,7 +145,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_param_actor, + share_params=self.experiment_config.share_policy_params, device=self.device, ) @@ -286,7 +284,6 @@ def get_critic(self, group: str) -> TensorDictModule: @dataclass class IppoConfig(AlgorithmConfig): - share_param_actor: bool = MISSING share_param_critic: bool = MISSING clip_epsilon: float = MISSING entropy_coef: float = MISSING diff --git a/benchmarl/algorithms/iql.py b/benchmarl/algorithms/iql.py index ada113e8..a36a56e2 100644 --- a/benchmarl/algorithms/iql.py +++ b/benchmarl/algorithms/iql.py @@ -22,14 +22,11 @@ class Iql(Algorithm): - def __init__( - self, delay_value: bool, loss_function: str, share_params: bool, **kwargs - ): + def __init__(self, delay_value: bool, loss_function: str, **kwargs): super().__init__(**kwargs) self.delay_value = delay_value self.loss_function = loss_function - self.share_params = share_params ############################# # Overridden abstract methods @@ -129,7 +126,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_params, + share_params=self.experiment_config.share_policy_params, device=self.device, ) if self.action_mask_spec is not None: @@ -200,7 +197,6 @@ class IqlConfig(AlgorithmConfig): delay_value: bool = MISSING loss_function: str = MISSING - share_params: bool = MISSING @staticmethod def associated_class() -> Type[Algorithm]: diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index e09d5717..3f59775e 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -32,21 +32,19 @@ class Isac(Algorithm): def __init__( self, - share_param_actor: bool = True, - share_param_critic: bool = True, - num_qvalue_nets: int = 2, - loss_function: str = "l2", - delay_qvalue: bool = True, - target_entropy: Union[float, str] = "auto", - alpha_init: float = 1.0, - min_alpha: Optional[float] = None, - max_alpha: Optional[float] = None, - fixed_alpha: bool = False, + share_param_critic: bool, + num_qvalue_nets: int, + loss_function: str, + delay_qvalue: bool, + target_entropy: Union[float, str], + alpha_init: float, + min_alpha: Optional[float], + max_alpha: Optional[float], + fixed_alpha: bool, **kwargs ): super().__init__(**kwargs) - self.share_param_actor = share_param_actor self.share_param_critic = share_param_critic self.delay_qvalue = delay_qvalue self.num_qvalue_nets = num_qvalue_nets @@ -192,7 +190,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_param_actor, + share_params=self.experiment_config.share_policy_params, device=self.device, ) @@ -374,7 +372,6 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: @dataclass class IsacConfig(AlgorithmConfig): - share_param_actor: bool = MISSING share_param_critic: bool = MISSING num_qvalue_nets: int = MISSING diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py index fc083b8f..a5a20b03 100644 --- a/benchmarl/algorithms/maddpg.py +++ b/benchmarl/algorithms/maddpg.py @@ -29,16 +29,10 @@ class Maddpg(Algorithm): def __init__( - self, - share_param_actor: bool = True, - share_param_critic: bool = True, - loss_function: str = "l2", - delay_value: bool = True, - **kwargs + self, share_param_critic: bool, loss_function: str, delay_value: bool, **kwargs ): super().__init__(**kwargs) - self.share_param_actor = share_param_actor self.share_param_critic = share_param_critic self.delay_value = delay_value self.loss_function = loss_function @@ -142,7 +136,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_param_actor, + share_params=self.experiment_config.share_policy_params, device=self.device, ) @@ -314,7 +308,6 @@ def get_value_module(self, group: str) -> TensorDictModule: @dataclass class MaddpgConfig(AlgorithmConfig): - share_param_actor: bool = MISSING share_param_critic: bool = MISSING loss_function: str = MISSING diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py index 34223fbc..22aee608 100644 --- a/benchmarl/algorithms/mappo.py +++ b/benchmarl/algorithms/mappo.py @@ -26,7 +26,6 @@ class Mappo(Algorithm): def __init__( self, - share_param_actor: bool, share_param_critic: bool, clip_epsilon: float, entropy_coef: bool, @@ -37,7 +36,6 @@ def __init__( ): super().__init__(**kwargs) - self.share_param_actor = share_param_actor self.share_param_critic = share_param_critic self.clip_epsilon = clip_epsilon self.entropy_coef = entropy_coef @@ -146,7 +144,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_param_actor, + share_params=self.experiment_config.share_policy_params, device=self.device, ) @@ -316,7 +314,6 @@ def get_critic(self, group: str) -> TensorDictModule: @dataclass class MappoConfig(AlgorithmConfig): - share_param_actor: bool = MISSING share_param_critic: bool = MISSING clip_epsilon: float = MISSING entropy_coef: float = MISSING diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index cf489d5e..eadf2b4e 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -32,21 +32,19 @@ class Masac(Algorithm): def __init__( self, - share_param_actor: bool = True, - share_param_critic: bool = True, - num_qvalue_nets: int = 2, - loss_function: str = "l2", - delay_qvalue: bool = True, - target_entropy: Union[float, str] = "auto", - alpha_init: float = 1.0, - min_alpha: Optional[float] = None, - max_alpha: Optional[float] = None, - fixed_alpha: bool = False, + share_param_critic: bool, + num_qvalue_nets: int, + loss_function: str, + delay_qvalue: bool, + target_entropy: Union[float, str], + alpha_init: float, + min_alpha: Optional[float], + max_alpha: Optional[float], + fixed_alpha: bool, **kwargs ): super().__init__(**kwargs) - self.share_param_actor = share_param_actor self.share_param_critic = share_param_critic self.delay_qvalue = delay_qvalue self.num_qvalue_nets = num_qvalue_nets @@ -192,7 +190,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_param_actor, + share_params=self.experiment_config.share_policy_params, device=self.device, ) @@ -454,7 +452,6 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule: @dataclass class MasacConfig(AlgorithmConfig): - share_param_actor: bool = MISSING share_param_critic: bool = MISSING num_qvalue_nets: int = MISSING diff --git a/benchmarl/algorithms/qmix.py b/benchmarl/algorithms/qmix.py index a2849b38..64388f2b 100644 --- a/benchmarl/algorithms/qmix.py +++ b/benchmarl/algorithms/qmix.py @@ -23,19 +23,13 @@ class Qmix(Algorithm): def __init__( - self, - mixing_embed_dim: int, - delay_value: bool, - loss_function: str, - share_params: bool, - **kwargs + self, mixing_embed_dim: int, delay_value: bool, loss_function: str, **kwargs ): super().__init__(**kwargs) self.delay_value = delay_value self.loss_function = loss_function self.mixing_embed_dim = mixing_embed_dim - self.share_params = share_params ############################# # Overridden abstract methods @@ -137,7 +131,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_params, + share_params=self.experiment_config.share_policy_params, device=self.device, ) if self.action_mask_spec is not None: @@ -232,7 +226,6 @@ class QmixConfig(AlgorithmConfig): mixing_embed_dim: int = MISSING delay_value: bool = MISSING loss_function: str = MISSING - share_params: bool = MISSING @staticmethod def associated_class() -> Type[Algorithm]: diff --git a/benchmarl/algorithms/vdn.py b/benchmarl/algorithms/vdn.py index faf0e8b9..64f5141e 100644 --- a/benchmarl/algorithms/vdn.py +++ b/benchmarl/algorithms/vdn.py @@ -22,14 +22,11 @@ class Vdn(Algorithm): - def __init__( - self, delay_value: bool, loss_function: str, share_params: bool, **kwargs - ): + def __init__(self, delay_value: bool, loss_function: str, **kwargs): super().__init__(**kwargs) self.delay_value = delay_value self.loss_function = loss_function - self.share_params = share_params ############################# # Overridden abstract methods @@ -131,7 +128,7 @@ def _get_policy_for_loss( input_has_agent_dim=True, n_agents=n_agents, centralised=False, - share_params=self.share_params, + share_params=self.experiment_config.share_policy_params, device=self.device, ) if self.action_mask_spec is not None: @@ -215,7 +212,6 @@ class VdnConfig(AlgorithmConfig): delay_value: bool = MISSING loss_function: str = MISSING - share_params: bool = MISSING @staticmethod def associated_class() -> Type[Algorithm]: diff --git a/benchmarl/conf/algorithm/iddpg.yaml b/benchmarl/conf/algorithm/iddpg.yaml index 6cd946fb..84a2db98 100644 --- a/benchmarl/conf/algorithm/iddpg.yaml +++ b/benchmarl/conf/algorithm/iddpg.yaml @@ -3,7 +3,6 @@ defaults: - _self_ -share_param_actor: True share_param_critic: True loss_function: "l2" delay_value: True diff --git a/benchmarl/conf/algorithm/ippo.yaml b/benchmarl/conf/algorithm/ippo.yaml index 59258794..3ed4bbd0 100644 --- a/benchmarl/conf/algorithm/ippo.yaml +++ b/benchmarl/conf/algorithm/ippo.yaml @@ -3,7 +3,6 @@ defaults: - _self_ -share_param_actor: True share_param_critic: True clip_epsilon: 0.2 entropy_coef: 0.0 diff --git a/benchmarl/conf/algorithm/iql.yaml b/benchmarl/conf/algorithm/iql.yaml index 0f4f26fc..65abeda8 100644 --- a/benchmarl/conf/algorithm/iql.yaml +++ b/benchmarl/conf/algorithm/iql.yaml @@ -5,4 +5,3 @@ defaults: delay_value: True loss_function: "l2" -share_params: True diff --git a/benchmarl/conf/algorithm/isac.yaml b/benchmarl/conf/algorithm/isac.yaml index 49f871bc..82c5d660 100644 --- a/benchmarl/conf/algorithm/isac.yaml +++ b/benchmarl/conf/algorithm/isac.yaml @@ -3,7 +3,6 @@ defaults: - _self_ -share_param_actor: True share_param_critic: True num_qvalue_nets: 2 diff --git a/benchmarl/conf/algorithm/maddpg.yaml b/benchmarl/conf/algorithm/maddpg.yaml index 99f47b0e..77e79416 100644 --- a/benchmarl/conf/algorithm/maddpg.yaml +++ b/benchmarl/conf/algorithm/maddpg.yaml @@ -3,7 +3,6 @@ defaults: - _self_ -share_param_actor: True share_param_critic: True loss_function: "l2" diff --git a/benchmarl/conf/algorithm/mappo.yaml b/benchmarl/conf/algorithm/mappo.yaml index 78aeeb99..bd1f18f3 100644 --- a/benchmarl/conf/algorithm/mappo.yaml +++ b/benchmarl/conf/algorithm/mappo.yaml @@ -3,7 +3,7 @@ defaults: - _self_ -share_param_actor: True + share_param_critic: True clip_epsilon: 0.2 entropy_coef: 0.0 diff --git a/benchmarl/conf/algorithm/masac.yaml b/benchmarl/conf/algorithm/masac.yaml index 1b76833d..08d3d778 100644 --- a/benchmarl/conf/algorithm/masac.yaml +++ b/benchmarl/conf/algorithm/masac.yaml @@ -3,7 +3,7 @@ defaults: - _self_ -share_param_actor: True + share_param_critic: True num_qvalue_nets: 2 diff --git a/benchmarl/conf/algorithm/qmix.yaml b/benchmarl/conf/algorithm/qmix.yaml index 95c2ebf4..0640e4f0 100644 --- a/benchmarl/conf/algorithm/qmix.yaml +++ b/benchmarl/conf/algorithm/qmix.yaml @@ -6,4 +6,3 @@ defaults: mixing_embed_dim: 32 delay_value: True loss_function: "l2" -share_params: True diff --git a/benchmarl/conf/algorithm/vdn.yaml b/benchmarl/conf/algorithm/vdn.yaml index 2f2e6fe0..f7630c67 100644 --- a/benchmarl/conf/algorithm/vdn.yaml +++ b/benchmarl/conf/algorithm/vdn.yaml @@ -4,4 +4,3 @@ defaults: delay_value: True loss_function: "l2" -share_params: True diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index dfc03df6..fb049a33 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -7,6 +7,7 @@ train_device: "cpu" gamma: 0.99 polyak_tau: 0.005 +share_policy_params: True lr: 0.00005 n_optimizer_steps: 45 clip_grad_norm: True diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py index c297f8b4..76265527 100644 --- a/benchmarl/environments/common.py +++ b/benchmarl/environments/common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import os import os.path as osp @@ -43,7 +45,7 @@ def __new__(cls, *args, **kwargs): def __init__(self, config: Dict[str, Any]): self.config = config - def update_config(self, config: Dict[str, Any]): + def update_config(self, config: Dict[str, Any]) -> Task: if self.config is None: self.config = config else: @@ -67,7 +69,7 @@ def supports_discrete_actions(self) -> bool: def max_steps(self, env: EnvBase) -> int: raise NotImplementedError - def has_render(self) -> bool: + def has_render(self, env: EnvBase) -> bool: raise NotImplementedError def group_map(self, env: EnvBase) -> Dict[str, List[str]]: @@ -90,7 +92,7 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]: @staticmethod def env_name() -> str: - return "vmas" + raise NotImplementedError @staticmethod def log_info(batch: TensorDictBase) -> Dict: @@ -108,7 +110,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]: yaml_path = Path(__file__).parent.parent / "conf" / "task" / f"{name}.yaml" return read_yaml_config(str(yaml_path.resolve())) - def get_from_yaml(self, path: Optional[str] = None): + def get_from_yaml(self, path: Optional[str] = None) -> Task: if path is None: task_name = self.name.lower() return self.update_config( diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py index f4e36a4d..de7a0ad3 100644 --- a/benchmarl/environments/smacv2/common.py +++ b/benchmarl/environments/smacv2/common.py @@ -1,7 +1,6 @@ from typing import Callable, Dict, List, Optional import torch - from tensordict import TensorDictBase from torchrl.data import CompositeSpec from torchrl.envs import EnvBase @@ -19,7 +18,6 @@ def get_env_fun( continuous_actions: bool, seed: Optional[int], ) -> Callable[[], EnvBase]: - return lambda: SMACv2Env(categorical_actions=True, seed=seed, **self.config) def supports_continuous_actions(self) -> bool: @@ -28,7 +26,7 @@ def supports_continuous_actions(self) -> bool: def supports_discrete_actions(self) -> bool: return True - def has_render(self) -> bool: + def has_render(self, env: EnvBase) -> bool: return True def max_steps(self, env: EnvBase) -> bool: @@ -85,9 +83,3 @@ def log_info(batch: TensorDictBase) -> Dict: @staticmethod def env_name() -> str: return "smacv2" - - -if __name__ == "__main__": - print(Smacv2Task.protoss_5_vs_5.get_from_yaml()) - env = Smacv2Task.protoss_5_vs_5.get_env_fun(0, False, 0)() - print(env.render(mode="rgb_array")) diff --git a/benchmarl/environments/vmas/common.py b/benchmarl/environments/vmas/common.py index d399e0b6..23153964 100644 --- a/benchmarl/environments/vmas/common.py +++ b/benchmarl/environments/vmas/common.py @@ -33,7 +33,7 @@ def supports_continuous_actions(self) -> bool: def supports_discrete_actions(self) -> bool: return True - def has_render(self) -> bool: + def has_render(self, env: EnvBase) -> bool: return True def max_steps(self, env: EnvBase) -> bool: @@ -64,7 +64,3 @@ def action_spec(self, env: EnvBase) -> CompositeSpec: @staticmethod def env_name() -> str: return "vmas" - - -if __name__ == "__main__": - print(VmasTask.BALANCE.get_from_yaml()) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 3edd2222..be1bbb72 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -9,7 +9,6 @@ from typing import Dict, List, Optional import torch - from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from tensordict.utils import _unravel_key_to_tuple, unravel_key @@ -38,6 +37,7 @@ class ExperimentConfig: train_device: str = MISSING gamma: float = MISSING polyak_tau: float = MISSING + share_policy_params: bool = MISSING lr: float = MISSING n_optimizer_steps: int = MISSING collected_frames_per_batch: int = MISSING @@ -398,13 +398,13 @@ def _collection_loop(self): self._evaluation_loop(iter=self.n_iters_performed) # End of step + self.n_iters_performed += 1 self.logger.commit() if ( self.config.checkpoint_interval > 0 and self.n_iters_performed % self.config.checkpoint_interval == 0 ): self.save_trainer() - self.n_iters_performed += 1 sampling_start = time.time() self.close() @@ -462,10 +462,11 @@ def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float: return float(gn) + @torch.no_grad() def _evaluation_loop(self, iter: int): evaluation_start = time.time() - with torch.no_grad() and set_exploration_type(ExplorationType.MODE): - if self.task.has_render(): + with set_exploration_type(ExplorationType.MODE): + if self.task.has_render(self.test_env): frames = [] def callback(env, td): diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py index 9542dfa3..47267549 100644 --- a/benchmarl/experiment/logger.py +++ b/benchmarl/experiment/logger.py @@ -1,5 +1,4 @@ import json - from pathlib import Path from typing import Any, Dict, List, Optional diff --git a/benchmarl/hydra_run.py b/benchmarl/hydra_config.py similarity index 100% rename from benchmarl/hydra_run.py rename to benchmarl/hydra_config.py diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index 468aae52..d4c7c03a 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -1,15 +1,15 @@ import pathlib + from abc import ABC, abstractmethod from dataclasses import asdict, dataclass from typing import Any, Callable, Dict, List, Optional, Sequence -from hydra.utils import get_class from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase, TensorDictSequential from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs import EnvBase -from benchmarl.utils import DEVICE_TYPING, read_yaml_config +from benchmarl.utils import class_from_name, DEVICE_TYPING, read_yaml_config def _check_spec(tensordict, spec): @@ -22,7 +22,7 @@ def parse_model_config(cfg: Dict[str, Any]) -> Dict[str, Any]: kwargs = {} for key, value in cfg.items(): if key.endswith("class") and value is not None: - value = get_class(cfg[key]) + value = class_from_name(cfg[key]) kwargs.update({key: value}) return kwargs @@ -74,21 +74,27 @@ def output_has_agent_dim(self) -> bool: def _perform_checks(self): if not self.input_has_agent_dim and not self.centralised: - assert False + raise ValueError( + "If input does not have an agent dimension the model should be marked as centralised" + ) if len(self.in_keys) > 1: - assert False + raise ValueError("Currently models support just one input key") if len(self.out_keys) > 1: - assert False + raise ValueError("Currently models support just one output key") if self.agent_group in self.input_spec.keys() and self.input_spec[ self.agent_group ].shape != (self.n_agents,): - assert False + raise ValueError( + "If the agent group is in the input specs, its shape should be the number of agents" + ) if self.agent_group in self.output_spec.keys() and self.output_spec[ self.agent_group ].shape != (self.n_agents,): - assert False + raise ValueError( + "If the agent group is in the output specs, its shape should be the number of agents" + ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # _check_spec(tensordict, self.input_spec) @@ -197,8 +203,14 @@ def get_model( ) -> Model: n_models = len(self.model_configs) - assert n_models > 0 - assert len(self.intermediate_sizes) == n_models - 1 + if not n_models > 0: + raise ValueError( + f"SequenceModelConfig expects n_models > 0, got {n_models}" + ) + if len(self.intermediate_sizes) != n_models - 1: + raise ValueError( + f"SequenceModelConfig intermediate_sizes len should be {n_models - 1}, got {len(self.intermediate_sizes)}" + ) out_has_agent_dim = output_has_agent_dim(share_params, centralised) next_centralised = not out_has_agent_dim @@ -246,6 +258,11 @@ def get_model( def associated_class(): return SequenceModel + def process_env_fun(self, env_fun: Callable[[], EnvBase]) -> Callable[[], EnvBase]: + for model_config in self.model_configs: + env_fun = model_config.process_env_fun(env_fun) + return env_fun + @staticmethod def get_from_yaml(path: Optional[str] = None): raise NotImplementedError diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py index 88cd2bfd..8a76e6f9 100644 --- a/benchmarl/models/mlp.py +++ b/benchmarl/models/mlp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, MISSING from typing import Optional, Sequence, Type @@ -56,12 +58,18 @@ def _perform_checks(self): super()._perform_checks() if self.input_has_agent_dim and self.input_leaf_spec.shape[-2] != self.n_agents: - assert False + raise ValueError( + "If the MLP input has the agent dimension," + " the second to last spec dimension should be the number of agents" + ) if ( self.output_has_agent_dim and self.output_leaf_spec.shape[-2] != self.n_agents ): - assert False + raise ValueError( + "If the MLP output has the agent dimension," + " the second to last spec dimension should be the number of agents" + ) def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key @@ -71,6 +79,9 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.input_has_agent_dim: res = self.mlp.forward(input) if not self.output_has_agent_dim: + # If we are here the module is centralised and parameter shared. + # Thus the multi-agent dimension has been expanded, + # We remove it without loss of data res = res[..., 0, :] # Does not have multi-agent input dimension @@ -103,7 +114,7 @@ def associated_class(): return Mlp @staticmethod - def get_from_yaml(path: Optional[str] = None): + def get_from_yaml(path: Optional[str] = None) -> MlpConfig: if path is None: return MlpConfig( **ModelConfig._load_from_yaml( diff --git a/examples/simple_hydra_run.py b/benchmarl/run.py similarity index 82% rename from examples/simple_hydra_run.py rename to benchmarl/run.py index bc5a941e..2328653e 100644 --- a/examples/simple_hydra_run.py +++ b/benchmarl/run.py @@ -1,10 +1,11 @@ import hydra -from benchmarl.hydra_run import load_experiment_from_hydra from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf +from benchmarl.hydra_config import load_experiment_from_hydra -@hydra.main(version_base=None, config_path="../benchmarl/conf", config_name="config") + +@hydra.main(version_base=None, config_path="conf", config_name="config") def hydra_experiment(cfg: DictConfig) -> None: hydra_choices = HydraConfig.get().runtime.choices task_name = hydra_choices.task diff --git a/benchmarl/utils.py b/benchmarl/utils.py index 8564e53f..a5444d49 100644 --- a/benchmarl/utils.py +++ b/benchmarl/utils.py @@ -1,3 +1,4 @@ +import importlib from typing import Any, Dict, Union import torch @@ -13,3 +14,14 @@ def read_yaml_config(config_file: str) -> Dict[str, Any]: if "defaults" in config_dict.keys(): del config_dict["defaults"] return config_dict + + +def class_from_name(name: str): + name_split = name.split(".") + module_name = ".".join(name_split[:-1]) + class_name = name_split[-1] + # load the module, will raise ImportError if module cannot be loaded + m = importlib.import_module(module_name) + # get the class, will raise AttributeError if class cannot be found + c = getattr(m, class_name) + return c diff --git a/examples/vmas_run.py b/examples/vmas_run.py index 21678c8a..60059e1d 100644 --- a/examples/vmas_run.py +++ b/examples/vmas_run.py @@ -1,7 +1,7 @@ import hydra from benchmarl.experiment import Experiment -from benchmarl.hydra_run import ( +from benchmarl.hydra_config import ( load_algorithm_config_from_hydra, load_experiment_config_from_hydra, load_model_config_from_hydra, diff --git a/requirements.txt b/requirements.txt index e2d09241..78620c47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -hydra-core +tqdm diff --git a/setup.py b/setup.py index 11332749..b9c73b3c 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,6 @@ author="Matteo Bettini", author_email="mb2389@cl.cam.ac.uk", packages=find_packages(), - install_requires=["torchrl"], + install_requires=["torchrl", "tqdm"], include_package_data=True, ) diff --git a/test/conftest.py b/test/conftest.py index 6328962e..26670db1 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,33 @@ +import os +import shutil +from pathlib import Path + import pytest from benchmarl.experiment import ExperimentConfig +from benchmarl.models import MlpConfig +from benchmarl.models.common import ModelConfig, SequenceModelConfig +from torch import nn + + +def pytest_sessionstart(session): + """ + Called after the Session object has been created and + before performing collection and entering the run test loop. + """ + folder_name = Path(os.getcwd()) + folder_name = folder_name / "tmp" + folder_name.mkdir(parents=False, exist_ok=True) + os.chdir(folder_name) + + +def pytest_sessionfinish(session, exitstatus): + """ + Called after whole test run finished, right before + returning the exit status to the system. + """ + folder_name = Path(os.getcwd()) / "tmp" + shutil.rmtree(folder_name) @pytest.fixture @@ -13,8 +40,20 @@ def experiment_config() -> ExperimentConfig: experiment_config.on_policy_minibatch_size = 10 experiment_config.off_policy_memory_size = 200 experiment_config.off_policy_train_batch_size = 100 - experiment_config.evaluation = False - experiment_config.loggers = [] - experiment_config.create_json = False - experiment_config.checkpoint_interval = 0 + experiment_config.evaluation = True + experiment_config.evaluation_episodes = 2 + experiment_config.loggers = ["csv"] + experiment_config.create_json = True + experiment_config.checkpoint_interval = 1 return experiment_config + + +@pytest.fixture +def mlp_sequence_config() -> ModelConfig: + return SequenceModelConfig( + model_configs=[ + MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear), + MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear), + ], + intermediate_sizes=[5], + ) diff --git a/test/test_algorithm.py b/test/test_algorithm.py new file mode 100644 index 00000000..22dd6f56 --- /dev/null +++ b/test/test_algorithm.py @@ -0,0 +1,20 @@ +import pytest + +from benchmarl.algorithms import algorithm_config_registry +from benchmarl.algorithms.common import AlgorithmConfig +from benchmarl.hydra_config import load_algorithm_config_from_hydra +from hydra import compose, initialize + + +@pytest.mark.parametrize("algo_name", algorithm_config_registry.keys()) +def test_loading_algorithms(algo_name): + with initialize(version_base=None, config_path="../benchmarl/conf"): + cfg = compose( + config_name="config", + overrides=[ + f"algorithm={algo_name}", + "task=vmas/balance", + ], + ) + algo_config: AlgorithmConfig = load_algorithm_config_from_hydra(cfg.algorithm) + assert algo_config == algorithm_config_registry[algo_name].get_from_yaml() diff --git a/test/test_algorithms.py b/test/test_algorithms.py deleted file mode 100644 index 1210fbb6..00000000 --- a/test/test_algorithms.py +++ /dev/null @@ -1,85 +0,0 @@ -import importlib - -import pytest -from benchmarl.algorithms import algorithm_config_registry - -from benchmarl.algorithms.common import AlgorithmConfig - -from benchmarl.environments import Smacv2Task, VmasTask -from benchmarl.experiment import Experiment -from benchmarl.models.common import SequenceModelConfig -from benchmarl.models.mlp import MlpConfig -from torch import nn - - -_has_vmas = importlib.util.find_spec("vmas") is not None -_has_smacv2 = importlib.util.find_spec("smacv2") is not None - - -@pytest.mark.skipif(not _has_vmas, reason="VMAS not found") -@pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) -@pytest.mark.parametrize("continuous", [True, False]) -def test_all_algos_vmas(algo_config, continuous, experiment_config): - task = VmasTask.BALANCE.get_from_yaml() - model_config = SequenceModelConfig( - model_configs=[ - MlpConfig(num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear), - MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear), - ], - intermediate_sizes=[5], - ) - experiment_config.prefer_continuous_actions = continuous - - experiment = Experiment( - algorithm_config=algo_config.get_from_yaml(), - model_config=model_config, - seed=0, - config=experiment_config, - task=task, - ) - experiment.run() - - -@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") -@pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) -def test_all_algos_smac(algo_config: AlgorithmConfig, experiment_config): - if algo_config.supports_discrete_actions(): - task = Smacv2Task.protoss_5_vs_5.get_from_yaml() - model_config = SequenceModelConfig( - model_configs=[ - MlpConfig( - num_cells=[8], activation_class=nn.Tanh, layer_class=nn.Linear - ), - MlpConfig( - num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear - ), - ], - intermediate_sizes=[5], - ) - - experiment = Experiment( - algorithm_config=algo_config.get_from_yaml(), - model_config=model_config, - seed=0, - config=experiment_config, - task=task, - ) - experiment.run() - - -# @pytest.mark.parametrize("algo_config", algorithm_config_registry.keys()) -# def test_all_algos_hydra(algo_config): -# with initialize(version_base=None, config_path="../benchmarl/conf"): -# cfg = compose( -# config_name="config", -# overrides=[ -# f"algorithm={algo_config}", -# "task=vmas/balance", -# "model.num_cells=[3]", -# "experiment.loggers=[]", -# ], -# return_hydra_config=True, -# ) -# task_name = cfg.hydra.runtime.choices.task -# experiment = load_experiment_from_hydra_config(cfg, task_name=task_name) -# experiment.run() diff --git a/test/test_models.py b/test/test_models.py new file mode 100644 index 00000000..0119ef6d --- /dev/null +++ b/test/test_models.py @@ -0,0 +1,45 @@ +import pytest + +from benchmarl.hydra_config import load_model_config_from_hydra +from benchmarl.models import model_config_registry + +from benchmarl.models.common import SequenceModelConfig +from hydra import compose, initialize + + +@pytest.mark.parametrize("model_name", model_config_registry.keys()) +def test_loading_simple_models(model_name): + with initialize(version_base=None, config_path="../benchmarl/conf"): + cfg = compose( + config_name="config", + overrides=[ + "algorithm=mappo", + "task=vmas/balance", + f"model=layers/{model_name}", + ], + ) + model_config = load_model_config_from_hydra(cfg.model) + assert model_config == model_config_registry[model_name].get_from_yaml() + + +@pytest.mark.parametrize("model_name", model_config_registry.keys()) +def test_loading_sequence_models(model_name, intermidiate_size=10): + with initialize(version_base=None, config_path="../benchmarl/conf"): + cfg = compose( + config_name="config", + overrides=[ + "algorithm=mappo", + "task=vmas/balance", + "model=sequence", + f"model/layers@model.layers.l1={model_name}", + f"model/layers@model.layers.l2={model_name}", + f"model.intermediate_sizes={[intermidiate_size]}", + ], + ) + hydra_model_config = load_model_config_from_hydra(cfg.model) + layer_config = model_config_registry[model_name].get_from_yaml() + yaml_config = SequenceModelConfig( + model_configs=[layer_config, layer_config], + intermediate_sizes=[intermidiate_size], + ) + assert hydra_model_config == yaml_config diff --git a/test/test_smacv2.py b/test/test_smacv2.py new file mode 100644 index 00000000..60cc56b1 --- /dev/null +++ b/test/test_smacv2.py @@ -0,0 +1,50 @@ +import importlib + +import pytest + +from benchmarl.algorithms import ( + algorithm_config_registry, + MappoConfig, + MasacConfig, + QmixConfig, +) +from benchmarl.algorithms.common import AlgorithmConfig +from benchmarl.environments import Smacv2Task +from benchmarl.experiment import Experiment + +_has_smacv2 = importlib.util.find_spec("smacv2") is not None + + +@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") +class TestSmacv2: + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) + @pytest.mark.parametrize("task", list(Smacv2Task)) + def test_all_algos( + self, algo_config: AlgorithmConfig, task, experiment_config, mlp_sequence_config + ): + if algo_config.supports_discrete_actions(): + task = task.get_from_yaml() + + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [QmixConfig, MappoConfig, MasacConfig]) + @pytest.mark.parametrize("task", list(Smacv2Task)) + def test_all_tasks( + self, algo_config: AlgorithmConfig, task, experiment_config, mlp_sequence_config + ): + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() diff --git a/test/test_task.py b/test/test_task.py new file mode 100644 index 00000000..db4b2eab --- /dev/null +++ b/test/test_task.py @@ -0,0 +1,21 @@ +import pytest + +from benchmarl.environments import Task, task_config_registry +from benchmarl.hydra_config import load_task_config_from_hydra +from hydra import compose, initialize + + +@pytest.mark.parametrize("task_name", task_config_registry.keys()) +def test_loading_tasks(task_name): + with initialize(version_base=None, config_path="../benchmarl/conf"): + cfg = compose( + config_name="config", + overrides=[ + "algorithm=mappo", + f"task={task_name}", + ], + return_hydra_config=True, + ) + task_name_hydra = cfg.hydra.runtime.choices.task + task: Task = load_task_config_from_hydra(cfg.task, task_name=task_name_hydra) + assert task == task_config_registry[task_name].get_from_yaml() diff --git a/test/test_vmas.py b/test/test_vmas.py new file mode 100644 index 00000000..07141a4c --- /dev/null +++ b/test/test_vmas.py @@ -0,0 +1,83 @@ +import importlib + +import pytest + +from benchmarl.algorithms import ( + algorithm_config_registry, + IppoConfig, + MaddpgConfig, + MappoConfig, + MasacConfig, + QmixConfig, +) +from benchmarl.algorithms.common import AlgorithmConfig +from benchmarl.environments import Task, VmasTask +from benchmarl.experiment import Experiment +from utils_experiment import ExperimentUtils + +_has_vmas = importlib.util.find_spec("vmas") is not None + + +@pytest.mark.skipif(not _has_vmas, reason="VMAS not found") +class TestVmas: + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) + @pytest.mark.parametrize("continuous", [True, False]) + @pytest.mark.parametrize("task", list(VmasTask)) + def test_all_algos_all_tasks( + self, + algo_config: AlgorithmConfig, + task: Task, + continuous, + experiment_config, + mlp_sequence_config, + ): + task = task.get_from_yaml() + experiment_config.prefer_continuous_actions = continuous + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig]) + @pytest.mark.parametrize("task", [VmasTask.BALANCE, VmasTask.SAMPLING]) + def test_reloading_trainer( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + mlp_sequence_config, + ): + ExperimentUtils.check_experiment_loading( + algo_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + experiment_config=experiment_config, + task=task.get_from_yaml(), + ) + + @pytest.mark.parametrize( + "algo_config", [QmixConfig, IppoConfig, MaddpgConfig, MasacConfig] + ) + @pytest.mark.parametrize("task", [VmasTask.NAVIGATION]) + @pytest.mark.parametrize("share_params", [True, False]) + def test_share_policy_params( + self, + algo_config: AlgorithmConfig, + task: Task, + share_params, + experiment_config, + mlp_sequence_config, + ): + experiment_config.share_policy_params = share_params + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() diff --git a/test/utils_experiment.py b/test/utils_experiment.py new file mode 100644 index 00000000..053cde74 --- /dev/null +++ b/test/utils_experiment.py @@ -0,0 +1,52 @@ +from benchmarl.algorithms.common import AlgorithmConfig +from benchmarl.environments import Task +from benchmarl.experiment import Experiment, ExperimentConfig +from benchmarl.models.common import ModelConfig + + +class ExperimentUtils: + @staticmethod + def check_experiment_loading( + algo_config: AlgorithmConfig, + task: Task, + experiment_config: ExperimentConfig, + model_config: ModelConfig, + ): + n_iters = experiment_config.n_iters + experiment = Experiment( + algorithm_config=algo_config, + model_config=model_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + policy = experiment.policy + losses = experiment.losses + exp_folder = experiment.folder_name + + experiment_config.n_iters = n_iters + 3 + experiment_config.restore_file = ( + exp_folder / "checkpoints" / f"checkpoint_{n_iters}.pt" + ) + experiment = Experiment( + algorithm_config=algo_config, + model_config=model_config, + seed=0, + config=experiment_config, + task=task, + ) + for param1, param2 in zip( + list(experiment.policy.parameters()), list(policy.parameters()) + ): + assert (param1 == param2).all() + for loss1, loss2 in zip(experiment.losses.values(), losses.values()): + for param1, param2 in zip( + list(loss1.parameters()), list(loss2.parameters()) + ): + assert (param1 == param2).all() + assert experiment.n_iters_performed == n_iters + assert experiment.folder_name == exp_folder + experiment.run() + assert experiment.n_iters_performed == n_iters + 3