diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index d1eaed19..dfc03df6 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -4,18 +4,20 @@ defaults:
 
 sampling_device: "cpu"
 train_device: "cpu"
+
 gamma: 0.99
 polyak_tau: 0.005
 lr: 0.00005
 n_optimizer_steps: 45
-collected_frames_per_batch: 60_000
-n_envs_per_worker: 600
-n_iters: 500
-prefer_continuous_actions: True
 clip_grad_norm: True
-clip_grad_val: 40
+clip_grad_val: 5
+prefer_continuous_actions: True
+
+collected_frames_per_batch: 6000
+n_envs_per_worker: 10
+n_iters: 500
 
-on_policy_minibatch_size: 4096
+on_policy_minibatch_size: 400
 
 off_policy_memory_size: 100_000
 off_policy_train_batch_size: 10_000
@@ -23,11 +25,11 @@ off_policy_prioritised_alpha: 0.7
 off_policy_prioritised_beta: 0.5
 
 evaluation: True
-evaluation_interval: 30
-evaluation_episodes: 200
+evaluation_interval: 20
+evaluation_episodes: 10
 
 loggers: [wandb]
 create_json: True
 
 restore_file: null
-checkpoint_interval: 100
+checkpoint_interval: 50
diff --git a/test/conf/experiment/base_experiment.yaml b/test/conf/experiment/base_experiment.yaml
deleted file mode 100644
index 446dd5dd..00000000
--- a/test/conf/experiment/base_experiment.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-defaults:
-  - experiment_config
-  - _self_
-
-sampling_device: "cpu"
-train_device: "cpu"
-gamma: 0.99
-polyak_tau: 0.005
-lr: 0.00005
-n_optimizer_steps: 10
-collected_frames_per_batch: 100
-n_envs_per_worker: 2
-n_iters: 3
-prefer_continuous_actions: True
-clip_grad_norm: True
-clip_grad_val: 40
-
-on_policy_minibatch_size: 100
-
-off_policy_memory_size: 200
-off_policy_train_batch_size: 100
-off_policy_prioritised_alpha: 0.7
-off_policy_prioritised_beta: 0.5
-
-evaluation: False
-evaluation_interval: 20
-evaluation_episodes: 200
-
-loggers: []
-create_json: False
-
-restore_file:
-checkpoint_interval: 0
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 00000000..6328962e
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,20 @@
+import pytest
+
+from benchmarl.experiment import ExperimentConfig
+
+
+@pytest.fixture
+def experiment_config() -> ExperimentConfig:
+    experiment_config: ExperimentConfig = ExperimentConfig.get_from_yaml()
+    experiment_config.n_iters = 3
+    experiment_config.n_optimizer_steps = 2
+    experiment_config.n_envs_per_worker = 2
+    experiment_config.collected_frames_per_batch = 100
+    experiment_config.on_policy_minibatch_size = 10
+    experiment_config.off_policy_memory_size = 200
+    experiment_config.off_policy_train_batch_size = 100
+    experiment_config.evaluation = False
+    experiment_config.loggers = []
+    experiment_config.create_json = False
+    experiment_config.checkpoint_interval = 0
+    return experiment_config
diff --git a/test/test_algorithms.py b/test/test_algorithms.py
index b68524c7..2c4c034b 100644
--- a/test/test_algorithms.py
+++ b/test/test_algorithms.py
@@ -1,18 +1,22 @@
-import pathlib
+import importlib
 
 import pytest
-
 from benchmarl.algorithms import algorithm_config_registry
+
 from benchmarl.environments import VmasTask
-from benchmarl.experiment import Experiment, ExperimentConfig
+from benchmarl.experiment import Experiment
 from benchmarl.models.common import SequenceModelConfig
 from benchmarl.models.mlp import MlpConfig
 from torch import nn
 
 
+_has_vmas = importlib.util.find_spec("vmas") is not None
+
+
+@pytest.mark.skipif(not _has_vmas, reason="VMAS not found")
 @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
 @pytest.mark.parametrize("continuous", [True, False])
-def test_all_algos_balance(algo_config, continuous):
+def test_all_algos_vmas(algo_config, continuous, experiment_config):
     task = VmasTask.BALANCE.get_from_yaml()
     model_config = SequenceModelConfig(
         model_configs=[
@@ -21,14 +25,6 @@ def test_all_algos_balance(algo_config, continuous):
         ],
         intermediate_sizes=[5],
     )
-    experiment_config: ExperimentConfig = ExperimentConfig.get_from_yaml(
-        str(
-            pathlib.Path(__file__).parent
-            / "conf"
-            / "experiment"
-            / "base_experiment.yaml"
-        )
-    )
     experiment_config.prefer_continuous_actions = continuous
 
     experiment = Experiment(