From bd7c03017514ba7a5009425563341862546a325b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Thu, 26 Dec 2024 22:40:15 -0500 Subject: [PATCH] Add agent tests in jax --- tests/jax/test_jax_agent_a2c.py | 185 +++++++++++++++++++++++++ tests/jax/test_jax_agent_cem.py | 142 +++++++++++++++++++ tests/jax/test_jax_agent_ddpg.py | 199 ++++++++++++++++++++++++++ tests/jax/test_jax_agent_ddqn.py | 174 +++++++++++++++++++++++ tests/jax/test_jax_agent_dqn.py | 174 +++++++++++++++++++++++ tests/jax/test_jax_agent_ppo.py | 207 ++++++++++++++++++++++++++++ tests/jax/test_jax_agent_rpo.py | 199 ++++++++++++++++++++++++++ tests/jax/test_jax_agent_sac.py | 195 ++++++++++++++++++++++++++ tests/jax/test_jax_agent_td3.py | 230 +++++++++++++++++++++++++++++++ 9 files changed, 1705 insertions(+) create mode 100644 tests/jax/test_jax_agent_a2c.py create mode 100644 tests/jax/test_jax_agent_cem.py create mode 100644 tests/jax/test_jax_agent_ddpg.py create mode 100644 tests/jax/test_jax_agent_ddqn.py create mode 100644 tests/jax/test_jax_agent_dqn.py create mode 100644 tests/jax/test_jax_agent_ppo.py create mode 100644 tests/jax/test_jax_agent_rpo.py create mode 100644 tests/jax/test_jax_agent_sac.py create mode 100644 tests/jax/test_jax_agent_td3.py diff --git a/tests/jax/test_jax_agent_a2c.py b/tests/jax/test_jax_agent_a2c.py new file mode 100644 index 00000000..c90787fd --- /dev/null +++ b/tests/jax/test_jax_agent_a2c.py @@ -0,0 +1,185 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.a2c import A2C as Agent +from skrl.agents.jax.a2c import A2C_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) 
+@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "CategoricalMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + mini_batches, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + entropy_loss_scale, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + if policy_structure in ["GaussianMixin"]: + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + elif policy_structure == "CategoricalMixin": + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "CategoricalMixin": + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "entropy_loss_scale": entropy_loss_scale, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + 
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_cem.py b/tests/jax/test_jax_agent_cem.py new file mode 100644 index 00000000..9ae7d2e9 --- /dev/null +++ b/tests/jax/test_jax_agent_cem.py @@ -0,0 +1,142 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.cem import CEM as Agent +from skrl.agents.jax.cem import CEM_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import categorical_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + percentile=st.floats(min_value=0, max_value=1), + discount_factor=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + rollouts, + percentile, + discount_factor, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "percentile": percentile, + "discount_factor": discount_factor, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": 
learning_starts, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_ddpg.py b/tests/jax/test_jax_agent_ddpg.py new file mode 100644 index 00000000..631ac3f0 --- /dev/null +++ b/tests/jax/test_jax_agent_ddpg.py @@ -0,0 +1,199 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.ddpg import DDPG as Agent +from skrl.agents.jax.ddpg import DDPG_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.noises.jax import GaussianNoise, OrnsteinUhlenbeckNoise +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + exploration=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + exploration_initial_scale=st.floats(min_value=0, max_value=1), + exploration_final_scale=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, 
*args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + exploration, + exploration_initial_scale, + exploration_final_scale, + exploration_timesteps, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "exploration": { + "initial_scale": exploration_initial_scale, + "final_scale": exploration_final_scale, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + # noise + # - exploration + if exploration is None: + cfg["exploration"]["noise"] = None + elif exploration is OrnsteinUhlenbeckNoise: + cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=1.0, device=env.device) + elif exploration is GaussianNoise: + cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + 
_check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_ddqn.py b/tests/jax/test_jax_agent_ddqn.py new file mode 100644 index 00000000..0b5fddb2 --- /dev/null +++ b/tests/jax/test_jax_agent_ddqn.py @@ -0,0 +1,174 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.dqn import DDQN as Agent +from skrl.agents.jax.dqn import DDQN_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.integers(min_value=0, max_value=5), + update_interval=st.integers(min_value=1, max_value=3), + target_update_interval=st.integers(min_value=1, max_value=5), + exploration_initial_epsilon=st.floats(min_value=0, max_value=1), + exploration_final_epsilon=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + update_interval, + target_update_interval, + exploration_initial_epsilon, + exploration_final_epsilon, + exploration_timesteps, + rewards_shaper, +): + # spaces + observation_space = 
gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "update_interval": update_interval, + "target_update_interval": target_update_interval, + "exploration": { + "initial_epsilon": exploration_initial_epsilon, + "final_epsilon": exploration_final_epsilon, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_dqn.py b/tests/jax/test_jax_agent_dqn.py new file mode 100644 index 00000000..bfc93fda --- /dev/null +++ b/tests/jax/test_jax_agent_dqn.py @@ -0,0 +1,174 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.dqn import DQN as Agent +from skrl.agents.jax.dqn import DQN_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, 
backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.integers(min_value=0, max_value=5), + update_interval=st.integers(min_value=1, max_value=3), + target_update_interval=st.integers(min_value=1, max_value=5), + exploration_initial_epsilon=st.floats(min_value=0, max_value=1), + exploration_final_epsilon=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + update_interval, + target_update_interval, + exploration_initial_epsilon, + exploration_final_epsilon, + exploration_timesteps, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "update_interval": 
update_interval, + "target_update_interval": target_update_interval, + "exploration": { + "initial_epsilon": exploration_initial_epsilon, + "final_epsilon": exploration_final_epsilon, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_ppo.py b/tests/jax/test_jax_agent_ppo.py new file mode 100644 index 00000000..27bafaff --- /dev/null +++ b/tests/jax/test_jax_agent_ppo.py @@ -0,0 +1,207 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.ppo import PPO as Agent +from skrl.agents.jax.ppo import PPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + 
entropy_loss_scale=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + kl_threshold=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "CategoricalMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + kl_threshold, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + if policy_structure in ["GaussianMixin"]: + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + elif policy_structure == "CategoricalMixin": + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "CategoricalMixin": + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "kl_threshold": kl_threshold, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "experiment": { + "directory": "", + "experiment_name": 
"", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_rpo.py b/tests/jax/test_jax_agent_rpo.py new file mode 100644 index 00000000..6fec0f01 --- /dev/null +++ b/tests/jax/test_jax_agent_rpo.py @@ -0,0 +1,199 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.rpo import RPO as Agent +from skrl.agents.jax.rpo import RPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + alpha=st.floats(min_value=0, max_value=1), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + kl_threshold=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) 
+@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + alpha, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + kl_threshold, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "alpha": alpha, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "kl_threshold": kl_threshold, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + 
"disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_sac.py b/tests/jax/test_jax_agent_sac.py new file mode 100644 index 00000000..112257a4 --- /dev/null +++ b/tests/jax/test_jax_agent_sac.py @@ -0,0 +1,195 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.sac import SAC as Agent +from skrl.agents.jax.sac import SAC_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + learn_entropy=st.booleans(), + entropy_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + initial_entropy_value=st.floats(min_value=0, max_value=1), + target_entropy=st.one_of(st.none(), st.floats(min_value=-1, max_value=1)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + learn_entropy, + entropy_learning_rate, + initial_entropy_value, + target_entropy, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = 
[ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "learn_entropy": learn_entropy, + "entropy_learning_rate": entropy_learning_rate, + "initial_entropy_value": initial_entropy_value, + "target_entropy": target_entropy, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_td3.py b/tests/jax/test_jax_agent_td3.py new file mode 100644 index 00000000..b1ec36f1 --- /dev/null +++ b/tests/jax/test_jax_agent_td3.py @@ -0,0 +1,230 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.td3 import TD3 as Agent +from skrl.agents.jax.td3 import TD3_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.noises.jax import GaussianNoise, OrnsteinUhlenbeckNoise +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax 
import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + exploration=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + exploration_initial_scale=st.floats(min_value=0, max_value=1), + exploration_final_scale=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + policy_delay=st.integers(min_value=1, max_value=3), + smooth_regularization_noise=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + smooth_regularization_clip=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + exploration, + exploration_initial_scale, + exploration_final_scale, + exploration_timesteps, + policy_delay, + smooth_regularization_noise, + smooth_regularization_clip, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_policy"] = deterministic_model( + observation_space=env.observation_space, + 
action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "exploration": { + "initial_scale": exploration_initial_scale, + "final_scale": exploration_final_scale, + "timesteps": exploration_timesteps, + }, + "policy_delay": policy_delay, + "smooth_regularization_clip": smooth_regularization_clip, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + # noise + # - exploration + if exploration is None: + cfg["exploration"]["noise"] = None + elif exploration is OrnsteinUhlenbeckNoise: + cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=1.0, device=env.device) + elif exploration is GaussianNoise: + cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + # - regularization + if smooth_regularization_noise is None: + cfg["smooth_regularization_noise"] = None + elif smooth_regularization_noise is OrnsteinUhlenbeckNoise: + cfg["smooth_regularization_noise"] = OrnsteinUhlenbeckNoise( + theta=0.1, sigma=0.2, base_scale=1.0, device=env.device + ) + elif smooth_regularization_noise is GaussianNoise: + cfg["smooth_regularization_noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = 
SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train()
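
Note for reviewers: the nine files above repeat one skeleton (dummy vectorized env -> wrapped env -> instantiated models -> agent config check -> short SequentialTrainer run). The condensed sketch below restates that pattern for the PPO variant using only names that appear in the diff; it is an illustrative summary, not part of the patch. Assumptions are labeled in the comments: it would live under tests/jax/ so the relative BaseEnv import resolves, it uses literal num_envs=2 and device="cpu" instead of the pytest/hypothesis sweep, and it passes PPO_DEFAULT_CONFIG directly rather than rebuilding the full cfg dict as the tests do.

# Condensed sketch of the pattern each new test file repeats: dummy vectorized
# env -> instantiated models -> agent config -> short SequentialTrainer run.
# PPO is shown with its default config; the actual tests build the cfg dict
# explicitly and sweep it with hypothesis strategies and pytest parametrization.
import gymnasium

from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.trainers.jax import SequentialTrainer
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from skrl.utils.spaces.jax import sample_space

from ..utils import BaseEnv  # helper used by all test files; assumes this sketch sits under tests/jax/


class Env(BaseEnv):
    def _sample_observation(self):
        return sample_space(self.observation_space, self.num_envs, backend="numpy")


# dummy continuous spaces and a small vectorized env, wrapped as a gymnasium env
observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,))
action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,))
env = wrap_env(Env(observation_space, action_space, num_envs=2, device="cpu"), wrapper="gymnasium")

# separate policy/value models built with the model instantiators, then their state dicts
network = [{"name": "net", "input": "STATES", "layers": [64, 64], "activations": "elu"}]
models = {
    "policy": gaussian_model(observation_space=env.observation_space, action_space=env.action_space,
                             device=env.device, network=network, output="ACTIONS"),
    "value": deterministic_model(observation_space=env.observation_space, action_space=env.action_space,
                                 device=env.device, network=network, output="ONE"),
}
for role, model in models.items():
    model.init_state_dict(role)

# rollout-sized memory, default agent config, and a short headless training run
cfg = PPO_DEFAULT_CONFIG.copy()
memory = RandomMemory(memory_size=cfg["rollouts"], num_envs=env.num_envs, device=env.device)
agent = PPO(models=models, memory=memory, cfg=cfg,
            observation_space=env.observation_space, action_space=env.action_space, device=env.device)
trainer = SequentialTrainer(cfg={"timesteps": 5 * cfg["rollouts"], "headless": True,
                                 "disable_progressbar": True, "close_environment_at_exit": False},
                            env=env, agents=agent)
trainer.train()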