From 2882db43266e61e0351b058d1f2af7d5bc16a81a Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Wed, 1 Jan 2025 00:03:57 -0500
Subject: [PATCH] Automatic mixed precision training in PyTorch (#243)

* Update PPO mixed-precision implementation in torch
* Add A2C mixed precision support in torch
* Add AMP mixed precision support in torch
* Add CEM mixed precision support in torch
* Add DDPG mixed precision support in torch
* Add DQN and DDQN mixed precision support in torch
* Add RPO mixed precision support in torch
* Add SAC mixed precision support in torch
* Add TD3 mixed precision support in torch
* Update docs
* Add PPO test in torch
* Add agent tests in torch
* Add agent tests in jax
* Avoid TypeError: Got unsupported ScalarType BFloat16
* Update CHANGELOG
---
 CHANGELOG.md | 1 +
 docs/source/api/agents/a2c.rst | 4 +
 docs/source/api/agents/amp.rst | 4 +
 docs/source/api/agents/cem.rst | 4 +
 docs/source/api/agents/ddpg.rst | 4 +
 docs/source/api/agents/ddqn.rst | 4 +
 docs/source/api/agents/dqn.rst | 4 +
 docs/source/api/agents/ppo.rst | 4 +
 docs/source/api/agents/rpo.rst | 4 +
 docs/source/api/agents/sac.rst | 4 +
 docs/source/api/agents/td3.rst | 4 +
 docs/source/api/agents/trpo.rst | 4 +
 skrl/agents/torch/a2c/a2c.py | 75 ++++----
 skrl/agents/torch/a2c/a2c_rnn.py | 77 +++++----
 skrl/agents/torch/amp/amp.py | 212 ++++++++++++-----------
 skrl/agents/torch/cem/cem.py | 29 +++-
 skrl/agents/torch/ddpg/ddpg.py | 72 +++++---
 skrl/agents/torch/ddpg/ddpg_rnn.py | 79 ++++++---
 skrl/agents/torch/dqn/ddqn.py | 72 +++++---
 skrl/agents/torch/dqn/dqn.py | 56 ++++--
 skrl/agents/torch/ppo/ppo.py | 16 +-
 skrl/agents/torch/ppo/ppo_rnn.py | 99 ++++++-----
 skrl/agents/torch/rpo/rpo.py | 104 ++++++-----
 skrl/agents/torch/rpo/rpo_rnn.py | 114 +++++++-----
 skrl/agents/torch/sac/sac.py | 113 +++++++-----
 skrl/agents/torch/sac/sac_rnn.py | 119 ++++++++-----
 skrl/agents/torch/td3/td3.py | 104 ++++++-----
 skrl/agents/torch/td3/td3_rnn.py | 108 +++++++-----
 skrl/utils/spaces/torch/spaces.py | 18 +-
 tests/jax/test_jax_agent_a2c.py | 185 ++++++++++++++++++++
 tests/jax/test_jax_agent_cem.py | 142 +++++++++++++++
 tests/jax/test_jax_agent_ddpg.py | 199 +++++++++++++++++++++
 tests/jax/test_jax_agent_ddqn.py | 174 +++++++++++++++++++
 tests/jax/test_jax_agent_dqn.py | 174 +++++++++++++++++++
 tests/jax/test_jax_agent_ppo.py | 207 ++++++++++++++++++++++
 tests/jax/test_jax_agent_rpo.py | 199 +++++++++++++++++++++
 tests/jax/test_jax_agent_sac.py | 195 +++++++++++++++++++++
 tests/jax/test_jax_agent_td3.py | 230 +++++++++++++++++++++++++
 tests/torch/test_torch_agent_a2c.py | 216 +++++++++++++++++++++++
 tests/torch/test_torch_agent_amp.py | 248 +++++++++++++++++++++++++++
 tests/torch/test_torch_agent_cem.py | 142 +++++++++++++++
 tests/torch/test_torch_agent_ddpg.py | 199 +++++++++++++++++++++
 tests/torch/test_torch_agent_ddqn.py | 174 +++++++++++++++++++
 tests/torch/test_torch_agent_dqn.py | 174 +++++++++++++++++++
 tests/torch/test_torch_agent_ppo.py | 238 +++++++++++++++++++++++++
 tests/torch/test_torch_agent_rpo.py | 229 +++++++++++++++++++++++++
 tests/torch/test_torch_agent_sac.py | 195 +++++++++++++++++++++
 tests/torch/test_torch_agent_td3.py | 230 +++++++++++++++++++++++++
 tests/torch/test_torch_agent_trpo.py | 199 +++++++++++++++++++++
 tests/utils.py | 104 +++--------
 50 files changed, 4917 insertions(+), 648 deletions(-)
 create mode 100644 tests/jax/test_jax_agent_a2c.py
 create mode 100644 tests/jax/test_jax_agent_cem.py
 create mode 100644 tests/jax/test_jax_agent_ddpg.py
 create mode 100644 tests/jax/test_jax_agent_ddqn.py
 create mode 
100644 tests/jax/test_jax_agent_dqn.py create mode 100644 tests/jax/test_jax_agent_ppo.py create mode 100644 tests/jax/test_jax_agent_rpo.py create mode 100644 tests/jax/test_jax_agent_sac.py create mode 100644 tests/jax/test_jax_agent_td3.py create mode 100644 tests/torch/test_torch_agent_a2c.py create mode 100644 tests/torch/test_torch_agent_amp.py create mode 100644 tests/torch/test_torch_agent_cem.py create mode 100644 tests/torch/test_torch_agent_ddpg.py create mode 100644 tests/torch/test_torch_agent_ddqn.py create mode 100644 tests/torch/test_torch_agent_dqn.py create mode 100644 tests/torch/test_torch_agent_ppo.py create mode 100644 tests/torch/test_torch_agent_rpo.py create mode 100644 tests/torch/test_torch_agent_sac.py create mode 100644 tests/torch/test_torch_agent_td3.py create mode 100644 tests/torch/test_torch_agent_trpo.py diff --git a/CHANGELOG.md b/CHANGELOG.md index aa715359..100cffb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - `parse_device` static method in ML framework configuration (used in library components to set up the device) - Model instantiator support for different shared model structures in PyTorch - Support for other model types than Gaussian and Deterministic in runners +- Support for automatic mixed precision training in PyTorch ### Changed - Call agent's `pre_interaction` method during evaluation diff --git a/docs/source/api/agents/a2c.rst b/docs/source/api/agents/a2c.rst index 4dccdc03..d2dc53ab 100644 --- a/docs/source/api/agents/a2c.rst +++ b/docs/source/api/agents/a2c.rst @@ -232,6 +232,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/amp.rst b/docs/source/api/agents/amp.rst index 6e4c8c0f..a8942dbc 100644 --- a/docs/source/api/agents/amp.rst +++ b/docs/source/api/agents/amp.rst @@ -237,6 +237,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/cem.rst b/docs/source/api/agents/cem.rst index e419a0c6..d4b9dfb5 100644 --- a/docs/source/api/agents/cem.rst +++ b/docs/source/api/agents/cem.rst @@ -175,6 +175,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - \- - .. centered:: :math:`\square` diff --git a/docs/source/api/agents/ddpg.rst b/docs/source/api/agents/ddpg.rst index 3ac4856c..92c26182 100644 --- a/docs/source/api/agents/ddpg.rst +++ b/docs/source/api/agents/ddpg.rst @@ -236,6 +236,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. 
centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst index c43add22..d2a499a9 100644 --- a/docs/source/api/agents/ddqn.rst +++ b/docs/source/api/agents/ddqn.rst @@ -184,6 +184,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst index c1f769ca..51731ab6 100644 --- a/docs/source/api/agents/dqn.rst +++ b/docs/source/api/agents/dqn.rst @@ -184,6 +184,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/ppo.rst b/docs/source/api/agents/ppo.rst index 4e824acd..a7284c23 100644 --- a/docs/source/api/agents/ppo.rst +++ b/docs/source/api/agents/ppo.rst @@ -248,6 +248,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/rpo.rst b/docs/source/api/agents/rpo.rst index faabf233..efaf526c 100644 --- a/docs/source/api/agents/rpo.rst +++ b/docs/source/api/agents/rpo.rst @@ -285,6 +285,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst index 369e4fb9..4620d35b 100644 --- a/docs/source/api/agents/sac.rst +++ b/docs/source/api/agents/sac.rst @@ -244,6 +244,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/td3.rst b/docs/source/api/agents/td3.rst index f9da210b..c68f71a1 100644 --- a/docs/source/api/agents/td3.rst +++ b/docs/source/api/agents/td3.rst @@ -258,6 +258,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - Automatic mixed precision + - .. centered:: :math:`\blacksquare` + - .. 
centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/docs/source/api/agents/trpo.rst b/docs/source/api/agents/trpo.rst index 85313c88..3f271da0 100644 --- a/docs/source/api/agents/trpo.rst +++ b/docs/source/api/agents/trpo.rst @@ -282,6 +282,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Mixed precision + - \- + - .. centered:: :math:`\square` + - .. centered:: :math:`\square` * - Distributed - Single Program Multi Data (SPMD) multi-GPU - .. centered:: :math:`\blacksquare` diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index 588767d4..e615f059 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -43,6 +43,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -142,6 +144,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: if self.policy is self.value: @@ -211,8 +219,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + self._current_log_prob = log_prob return actions, log_prob, outputs @@ -261,8 +270,9 @@ def record_transition( rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") - values = self._value_preprocessor(values, inverse=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: @@ -375,13 +385,13 @@ def compute_gae( return returns, advantages # compute returns and advantages - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): self.value.train(False) last_values, _, _ = self.value.act( {"states": self._state_preprocessor(self._current_next_states.float())}, role="value" ) self.value.train(True) - last_values = self._value_preprocessor(last_values, inverse=True) + last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") 
returns, advantages = compute_gae( @@ -409,49 +419,56 @@ def compute_gae( # mini-batches loop for sampled_states, sampled_actions, sampled_log_prob, sampled_returns, sampled_advantages in sampled_batches: - sampled_states = self._state_preprocessor(sampled_states, train=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - _, next_log_prob, _ = self.policy.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" - ) + sampled_states = self._state_preprocessor(sampled_states, train=True) - # compute approximate KL divergence for KLAdaptive learning rate scheduler - if self._learning_rate_scheduler: - if isinstance(self.scheduler, KLAdaptiveLR): - with torch.no_grad(): - ratio = next_log_prob - sampled_log_prob - kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() - kl_divergences.append(kl_divergence) + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" + ) - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 + # compute approximate KL divergence for KLAdaptive learning rate scheduler + if self._learning_rate_scheduler: + if isinstance(self.scheduler, KLAdaptiveLR): + with torch.no_grad(): + ratio = next_log_prob - sampled_log_prob + kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() + kl_divergences.append(kl_divergence) + + # compute entropy loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 - # compute policy loss - policy_loss = -(sampled_advantages * next_log_prob).mean() + # compute policy loss + policy_loss = -(sampled_advantages * next_log_prob).mean() - # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") + # compute value loss + predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") - value_loss = F.mse_loss(sampled_returns, predicted_values) + value_loss = F.mse_loss(sampled_returns, predicted_values) # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() if self.policy is not self.value: self.value.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.optimizer) if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: nn.utils.clip_grad_norm_( itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip ) - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index b241baf4..b01d45b0 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -43,6 +43,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -142,6 +144,12 @@ 
def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: if self.policy is self.value: @@ -248,8 +256,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), **rnn}, role="policy" + ) + self._current_log_prob = log_prob if self._rnn: self._rnn_final_states["policy"] = outputs.get("rnn", []) @@ -301,9 +312,10 @@ def record_transition( rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") - values = self._value_preprocessor(values, inverse=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} + values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: @@ -446,14 +458,14 @@ def compute_gae( return returns, advantages # compute returns and advantages - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} last_values, _, _ = self.value.act( {"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value" ) self.value.train(True) - last_values = self._value_preprocessor(last_values, inverse=True) + last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") returns, advantages = compute_gae( @@ -523,48 +535,55 @@ def compute_gae( "terminated": sampled_dones, } - sampled_states = self._state_preprocessor(sampled_states, train=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - _, next_log_prob, _ = self.policy.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy" - ) + sampled_states = self._state_preprocessor(sampled_states, train=True) - # compute approximate KL divergence for KLAdaptive learning rate scheduler - if isinstance(self.scheduler, KLAdaptiveLR): - with torch.no_grad(): - ratio = next_log_prob - sampled_log_prob - kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() - kl_divergences.append(kl_divergence) + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy" + ) - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = 
-self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 + # compute approximate KL divergence for KLAdaptive learning rate scheduler + if isinstance(self.scheduler, KLAdaptiveLR): + with torch.no_grad(): + ratio = next_log_prob - sampled_log_prob + kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() + kl_divergences.append(kl_divergence) + + # compute entropy loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 - # compute policy loss - policy_loss = -(sampled_advantages * next_log_prob).mean() + # compute policy loss + policy_loss = -(sampled_advantages * next_log_prob).mean() - # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value") + # compute value loss + predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value") - value_loss = F.mse_loss(sampled_returns, predicted_values) + value_loss = F.mse_loss(sampled_returns, predicted_values) # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() if self.policy is not self.value: self.value.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.optimizer) if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: nn.utils.clip_grad_norm_( itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip ) - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index de4ffda4..e5ad4570 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -60,6 +60,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -204,6 +206,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None and self.discriminator is not None: self.optimizer = torch.optim.Adam( @@ -308,8 +316,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": states}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": states}, role="policy") - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act({"states": states}, role="policy") + self._current_log_prob = log_prob return actions, log_prob, outputs @@ -361,18 +370,20 @@ def 
record_transition( if self._rewards_shaper is not None: rewards = self._rewards_shaper(rewards, timestep, timesteps) - with torch.no_grad(): + # compute values + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value") - values = self._value_preprocessor(values, inverse=True) + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: rewards += self._discount_factor * values * truncated - with torch.no_grad(): + # compute next values + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): next_values, _, _ = self.value.act({"states": self._state_preprocessor(next_states)}, role="value") - next_values = self._value_preprocessor(next_values, inverse=True) - next_values *= infos["terminate"].view(-1, 1).logical_not() + next_values = self._value_preprocessor(next_values, inverse=True) + next_values *= infos["terminate"].view(-1, 1).logical_not() self.memory.add_samples( states=states, @@ -490,7 +501,7 @@ def compute_gae( rewards = self.memory.get_tensor_by_name("rewards") amp_states = self.memory.get_tensor_by_name("amp_states") - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): amp_logits, _, _ = self.discriminator.act( {"states": self._amp_state_preprocessor(amp_states)}, role="discriminator" ) @@ -554,120 +565,127 @@ def compute_gae( _, ) in enumerate(sampled_batches): - sampled_states = self._state_preprocessor(sampled_states, train=True) - - _, next_log_prob, _ = self.policy.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" - ) - - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 - - # compute policy loss - ratio = torch.exp(next_log_prob - sampled_log_prob) - surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip( - ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip - ) - - policy_loss = -torch.min(surrogate, surrogate_clipped).mean() + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") + sampled_states = self._state_preprocessor(sampled_states, train=True) - if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip( - predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="policy" ) - value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) - # compute discriminator loss - if self._discriminator_batch_size: - sampled_amp_states = self._amp_state_preprocessor( - sampled_amp_states[0 : self._discriminator_batch_size], train=True + # compute entropy loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 + + # compute policy loss + ratio = torch.exp(next_log_prob - sampled_log_prob) + surrogate = sampled_advantages * ratio + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip ) - sampled_amp_replay_states = self._amp_state_preprocessor( - 
sampled_replay_batches[batch_index][0][0 : self._discriminator_batch_size], train=True - ) - sampled_amp_motion_states = self._amp_state_preprocessor( - sampled_motion_batches[batch_index][0][0 : self._discriminator_batch_size], train=True - ) - else: - sampled_amp_states = self._amp_state_preprocessor(sampled_amp_states, train=True) - sampled_amp_replay_states = self._amp_state_preprocessor( - sampled_replay_batches[batch_index][0], train=True + + policy_loss = -torch.min(surrogate, surrogate_clipped).mean() + + # compute value loss + predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value") + + if self._clip_predicted_values: + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) + value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) + + # compute discriminator loss + if self._discriminator_batch_size: + sampled_amp_states = self._amp_state_preprocessor( + sampled_amp_states[0 : self._discriminator_batch_size], train=True + ) + sampled_amp_replay_states = self._amp_state_preprocessor( + sampled_replay_batches[batch_index][0][0 : self._discriminator_batch_size], train=True + ) + sampled_amp_motion_states = self._amp_state_preprocessor( + sampled_motion_batches[batch_index][0][0 : self._discriminator_batch_size], train=True + ) + else: + sampled_amp_states = self._amp_state_preprocessor(sampled_amp_states, train=True) + sampled_amp_replay_states = self._amp_state_preprocessor( + sampled_replay_batches[batch_index][0], train=True + ) + sampled_amp_motion_states = self._amp_state_preprocessor( + sampled_motion_batches[batch_index][0], train=True + ) + + sampled_amp_motion_states.requires_grad_(True) + amp_logits, _, _ = self.discriminator.act({"states": sampled_amp_states}, role="discriminator") + amp_replay_logits, _, _ = self.discriminator.act( + {"states": sampled_amp_replay_states}, role="discriminator" ) - sampled_amp_motion_states = self._amp_state_preprocessor( - sampled_motion_batches[batch_index][0], train=True + amp_motion_logits, _, _ = self.discriminator.act( + {"states": sampled_amp_motion_states}, role="discriminator" ) - sampled_amp_motion_states.requires_grad_(True) - amp_logits, _, _ = self.discriminator.act({"states": sampled_amp_states}, role="discriminator") - amp_replay_logits, _, _ = self.discriminator.act( - {"states": sampled_amp_replay_states}, role="discriminator" - ) - amp_motion_logits, _, _ = self.discriminator.act( - {"states": sampled_amp_motion_states}, role="discriminator" - ) - - amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0) + amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0) - # discriminator prediction loss - discriminator_loss = 0.5 * ( - nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits)) - + torch.nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits)) - ) - - # discriminator logit regularization - if self._discriminator_logit_regularization_scale: - logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight) - discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum( - torch.square(logit_weights) + # discriminator prediction loss + discriminator_loss = 0.5 * ( + nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits)) + + torch.nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits)) ) - # discriminator gradient penalty - if 
self._discriminator_gradient_penalty_scale: - amp_motion_gradient = torch.autograd.grad( - amp_motion_logits, - sampled_amp_motion_states, - grad_outputs=torch.ones_like(amp_motion_logits), - create_graph=True, - retain_graph=True, - only_inputs=True, - ) - gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean() - discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty - - # discriminator weight decay - if self._discriminator_weight_decay_scale: - weights = [ - torch.flatten(module.weight) - for module in self.discriminator.modules() - if isinstance(module, torch.nn.Linear) - ] - weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1))) - discriminator_loss += self._discriminator_weight_decay_scale * weight_decay - - discriminator_loss *= self._discriminator_loss_scale + # discriminator logit regularization + if self._discriminator_logit_regularization_scale: + logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight) + discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum( + torch.square(logit_weights) + ) + + # discriminator gradient penalty + if self._discriminator_gradient_penalty_scale: + amp_motion_gradient = torch.autograd.grad( + amp_motion_logits, + sampled_amp_motion_states, + grad_outputs=torch.ones_like(amp_motion_logits), + create_graph=True, + retain_graph=True, + only_inputs=True, + ) + gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean() + discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty + + # discriminator weight decay + if self._discriminator_weight_decay_scale: + weights = [ + torch.flatten(module.weight) + for module in self.discriminator.modules() + if isinstance(module, torch.nn.Linear) + ] + weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1))) + discriminator_loss += self._discriminator_weight_decay_scale * weight_decay + + discriminator_loss *= self._discriminator_loss_scale # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss + discriminator_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss + discriminator_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() self.value.reduce_parameters() self.discriminator.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_( itertools.chain( self.policy.parameters(), self.value.parameters(), self.discriminator.parameters() ), self._grad_norm_clip, ) - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/cem/cem.py b/skrl/agents/torch/cem/cem.py index 4daf2ee7..5f5b3492 100644 --- a/skrl/agents/torch/cem/cem.py +++ b/skrl/agents/torch/cem/cem.py @@ -32,6 +32,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -114,8 +116,14 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + self._episode_tracking = [] + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = 
torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None: self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._learning_rate) @@ -168,7 +176,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": states}, role="policy") # sample stochastic actions - return self.policy.act({"states": states}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + return self.policy.act({"states": states}, role="policy") def record_transition( self, @@ -306,17 +315,21 @@ def _update(self, timestep: int, timesteps: int) -> None: elite_states = torch.cat([sampled_states[limits[i][0] : limits[i][1]] for i in indexes[:, 0]], dim=0) elite_actions = torch.cat([sampled_actions[limits[i][0] : limits[i][1]] for i in indexes[:, 0]], dim=0) - # compute scores for the elite states - _, _, outputs = self.policy.act({"states": elite_states}, role="policy") - scores = outputs["net_output"] + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute policy loss - policy_loss = F.cross_entropy(scores, elite_actions.view(-1)) + # compute scores for the elite states + _, _, outputs = self.policy.act({"states": elite_states}, role="policy") + scores = outputs["net_output"] + + # compute policy loss + policy_loss = F.cross_entropy(scores, elite_actions.view(-1)) # optimization step self.optimizer.zero_grad() - policy_loss.backward() - self.optimizer.step() + self.scaler.scale(policy_loss).backward() + + self.scaler.step(self.optimizer) + self.scaler.update() # update learning rate if self._learning_rate_scheduler: diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index 25983eba..52fc37c5 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -44,6 +44,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -157,6 +159,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizers and learning rate schedulers if self.policy is not None and self.critic is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) @@ -217,7 +225,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample deterministic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") # add exloration noise if self._exploration_noise is not None: @@ -357,48 +366,65 @@ def _update(self, timestep: int, timesteps: int) -> None: names=self._tensors_names, batch_size=self._batch_size )[0] - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = 
self._state_preprocessor(sampled_next_states, train=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute target values - with torch.no_grad(): - next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - target_q_values, _, _ = self.target_critic.act( - {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic" - ) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + # compute target values + with torch.no_grad(): + next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") - # compute critic loss - critic_values, _, _ = self.critic.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="critic" - ) + target_q_values, _, _ = self.target_critic.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic" + ) + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + ) + + # compute critic loss + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic" + ) - critic_loss = F.mse_loss(critic_values, target_values) + critic_loss = F.mse_loss(critic_values, target_values) # optimization step (critic) self.critic_optimizer.zero_grad() - critic_loss.backward() + self.scaler.scale(critic_loss).backward() + if config.torch.is_distributed: self.critic.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip) - self.critic_optimizer.step() - # compute policy (actor) loss - actions, _, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_values, _, _ = self.critic.act({"states": sampled_states, "taken_actions": actions}, role="critic") + self.scaler.step(self.critic_optimizer) + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, _, _ = self.policy.act({"states": sampled_states}, role="policy") + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": actions}, role="critic" + ) - policy_loss = -critic_values.mean() + policy_loss = -critic_values.mean() # optimization step (policy) self.policy_optimizer.zero_grad() - policy_loss.backward() + self.scaler.scale(policy_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.policy_optimizer.step() + + self.scaler.step(self.policy_optimizer) + + self.scaler.update() # called once, after optimizers have been stepped # update target networks self.target_policy.update_parameters(self.policy, polyak=self._polyak) diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 89ddebca..d032527d 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -44,6 +44,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { 
"directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -157,6 +159,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizers and learning rate schedulers if self.policy is not None and self.critic is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) @@ -238,7 +246,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample deterministic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") if self._rnn: self._rnn_final_states["policy"] = outputs.get("rnn", []) @@ -407,52 +416,68 @@ def _update(self, timestep: int, timesteps: int) -> None: )[0] rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute target values - with torch.no_grad(): - next_actions, _, _ = self.target_policy.act( - {"states": sampled_next_states, **rnn_policy}, role="target_policy" - ) + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - target_q_values, _, _ = self.target_critic.act( - {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic" - ) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + # compute target values + with torch.no_grad(): + next_actions, _, _ = self.target_policy.act( + {"states": sampled_next_states, **rnn_policy}, role="target_policy" + ) - # compute critic loss - critic_values, _, _ = self.critic.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic" - ) + target_q_values, _, _ = self.target_critic.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, + role="target_critic", + ) + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + ) + + # compute critic loss + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic" + ) - critic_loss = F.mse_loss(critic_values, target_values) + critic_loss = F.mse_loss(critic_values, target_values) # optimization step (critic) self.critic_optimizer.zero_grad() - critic_loss.backward() + self.scaler.scale(critic_loss).backward() + if config.torch.is_distributed: self.critic.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip) - self.critic_optimizer.step() - # compute policy (actor) loss - actions, _, _ = self.policy.act({"states": sampled_states, 
**rnn_policy}, role="policy") - critic_values, _, _ = self.critic.act( - {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic" - ) + self.scaler.step(self.critic_optimizer) + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, _, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") + critic_values, _, _ = self.critic.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic" + ) - policy_loss = -critic_values.mean() + policy_loss = -critic_values.mean() # optimization step (policy) self.policy_optimizer.zero_grad() - policy_loss.backward() + self.scaler.scale(policy_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.policy_optimizer.step() + + self.scaler.step(self.policy_optimizer) + + self.scaler.update() # called once, after optimizers have been stepped # update target networks self.target_policy.update_parameters(self.policy, polyak=self._polyak) diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py index fdb73e84..9fda1b1d 100644 --- a/skrl/agents/torch/dqn/ddqn.py +++ b/skrl/agents/torch/dqn/ddqn.py @@ -43,6 +43,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -147,6 +149,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.q_network is not None: self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=self._learning_rate) @@ -212,9 +220,10 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens indexes = (torch.rand(states.shape[0], device=self.device) >= epsilon).nonzero().view(-1) if indexes.numel(): - actions[indexes] = torch.argmax( - self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True - ) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions[indexes] = torch.argmax( + self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True + ) # record epsilon self.track_data("Exploration / Exploration epsilon", epsilon) @@ -322,37 +331,48 @@ def _update(self, timestep: int, timesteps: int) -> None: names=self.tensors_names, batch_size=self._batch_size )[0] - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - - # compute target values - with torch.no_grad(): - next_q_values, _, _ = self.target_q_network.act( - {"states": sampled_next_states}, role="target_q_network" - ) - - target_q_values = torch.gather( - next_q_values, + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, 
train=True) + + # compute target values + with torch.no_grad(): + next_q_values, _, _ = self.target_q_network.act( + {"states": sampled_next_states}, role="target_q_network" + ) + + target_q_values = torch.gather( + next_q_values, + dim=1, + index=torch.argmax( + self.q_network.act({"states": sampled_next_states}, role="q_network")[0], + dim=1, + keepdim=True, + ), + ) + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + ) + + # compute Q-network loss + q_values = torch.gather( + self.q_network.act({"states": sampled_states}, role="q_network")[0], dim=1, - index=torch.argmax( - self.q_network.act({"states": sampled_next_states}, role="q_network")[0], dim=1, keepdim=True - ), + index=sampled_actions.long(), ) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values - - # compute Q-network loss - q_values = torch.gather( - self.q_network.act({"states": sampled_states}, role="q_network")[0], dim=1, index=sampled_actions.long() - ) - q_network_loss = F.mse_loss(q_values, target_values) + q_network_loss = F.mse_loss(q_values, target_values) # optimize Q-network self.optimizer.zero_grad() - q_network_loss.backward() + self.scaler.scale(q_network_loss).backward() + if config.torch.is_distributed: self.q_network.reduce_parameters() - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update target network if not timestep % self._target_update_interval: diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index 318b56a3..f73e89aa 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -43,6 +43,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -147,6 +149,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.q_network is not None: self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=self._learning_rate) @@ -212,9 +220,10 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens indexes = (torch.rand(states.shape[0], device=self.device) >= epsilon).nonzero().view(-1) if indexes.numel(): - actions[indexes] = torch.argmax( - self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True - ) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions[indexes] = torch.argmax( + self.q_network.act({"states": states[indexes]}, role="q_network")[0], dim=1, keepdim=True + ) # record epsilon self.track_data("Exploration / Exploration epsilon", epsilon) @@ -322,31 +331,40 @@ def _update(self, timestep: int, timesteps: int) -> None: names=self.tensors_names, batch_size=self._batch_size )[0] - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - # compute target values - with torch.no_grad(): - next_q_values, _, _ = 
self.target_q_network.act( - {"states": sampled_next_states}, role="target_q_network" - ) + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - target_q_values = torch.max(next_q_values, dim=-1, keepdim=True)[0] - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + # compute target values + with torch.no_grad(): + next_q_values, _, _ = self.target_q_network.act( + {"states": sampled_next_states}, role="target_q_network" + ) - # compute Q-network loss - q_values = torch.gather( - self.q_network.act({"states": sampled_states}, role="q_network")[0], dim=1, index=sampled_actions.long() - ) + target_q_values = torch.max(next_q_values, dim=-1, keepdim=True)[0] + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + ) + + # compute Q-network loss + q_values = torch.gather( + self.q_network.act({"states": sampled_states}, role="q_network")[0], + dim=1, + index=sampled_actions.long(), + ) - q_network_loss = F.mse_loss(q_values, target_values) + q_network_loss = F.mse_loss(q_values, target_values) # optimize Q-network self.optimizer.zero_grad() - q_network_loss.backward() + self.scaler.scale(q_network_loss).backward() + if config.torch.is_distributed: self.q_network.reduce_parameters() - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update target network if not timestep % self._target_update_interval: diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 7c15a9ce..597a4ce0 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -50,7 +50,7 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) - "mixed_precision": False, # mixed torch.float32 and torch.float16 precision for higher performance + "mixed_precision": False, # enable automatic mixed precision for higher performance "experiment": { "directory": "", # experiment's parent directory @@ -162,7 +162,7 @@ def __init__( # set up automatic mixed precision self._device_type = torch.device(device).type - self._scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: @@ -400,13 +400,13 @@ def compute_gae( return returns, advantages # compute returns and advantages - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): self.value.train(False) last_values, _, _ = self.value.act( {"states": self._state_preprocessor(self._current_next_states.float())}, role="value" ) self.value.train(True) - last_values = self._value_preprocessor(last_values, inverse=True) + last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") returns, advantages = compute_gae( @@ -487,7 +487,7 @@ def compute_gae( # optimization step self.optimizer.zero_grad() - self._scaler.scale(policy_loss + entropy_loss + value_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss).backward() if config.torch.is_distributed: self.policy.reduce_parameters() @@ -495,7 +495,7 @@ def compute_gae( self.value.reduce_parameters() if 
self._grad_norm_clip > 0: - self._scaler.unscale_(self.optimizer) + self.scaler.unscale_(self.optimizer) if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: @@ -503,8 +503,8 @@ def compute_gae( itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip ) - self._scaler.step(self.optimizer) - self._scaler.update() + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index 9e22be4e..16167ffe 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -50,6 +50,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -156,6 +158,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: if self.policy is self.value: @@ -263,8 +271,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), **rnn}, role="policy" + ) + self._current_log_prob = log_prob if self._rnn: self._rnn_final_states["policy"] = outputs.get("rnn", []) @@ -316,9 +327,10 @@ def record_transition( rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") - values = self._value_preprocessor(values, inverse=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} + values, _, outputs = self.value.act({"states": self._state_preprocessor(states), **rnn}, role="value") + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: @@ -461,14 +473,14 @@ def compute_gae( return returns, advantages # compute returns and advantages - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} last_values, _, _ = self.value.act( {"states": self._state_preprocessor(self._current_next_states.float()), **rnn}, role="value" ) self.value.train(True) - last_values = 
self._value_preprocessor(last_values, inverse=True) + last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") returns, advantages = compute_gae( @@ -541,61 +553,68 @@ def compute_gae( "terminated": sampled_dones, } - sampled_states = self._state_preprocessor(sampled_states, train=not epoch) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - _, next_log_prob, _ = self.policy.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy" - ) + sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - # compute approximate KL divergence - with torch.no_grad(): - ratio = next_log_prob - sampled_log_prob - kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() - kl_divergences.append(kl_divergence) + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="policy" + ) - # early stopping with KL divergence - if self._kl_threshold and kl_divergence > self._kl_threshold: - break + # compute approximate KL divergence + with torch.no_grad(): + ratio = next_log_prob - sampled_log_prob + kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() + kl_divergences.append(kl_divergence) - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 + # early stopping with KL divergence + if self._kl_threshold and kl_divergence > self._kl_threshold: + break - # compute policy loss - ratio = torch.exp(next_log_prob - sampled_log_prob) - surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip( - ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip - ) + # compute entropy loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 - policy_loss = -torch.min(surrogate, surrogate_clipped).mean() + # compute policy loss + ratio = torch.exp(next_log_prob - sampled_log_prob) + surrogate = sampled_advantages * ratio + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) - # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value") + policy_loss = -torch.min(surrogate, surrogate_clipped).mean() - if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip( - predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip - ) - value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) + # compute value loss + predicted_values, _, _ = self.value.act({"states": sampled_states, **rnn_value}, role="value") + + if self._clip_predicted_values: + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) + value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() if self.policy is not self.value: self.value.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.optimizer) if self.policy is self.value: 
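The PPO and PPO-RNN hunks above (and the RPO, SAC and TD3 ones that follow) all apply one and the same single-optimizer recipe: run the forward pass and loss under torch.autocast, backpropagate the scaled loss, unscale before gradient clipping, then step and update the scaler. As a point of reference, a minimal standalone sketch of that recipe is given below; the network, data and clip value are illustrative placeholders, not skrl code.

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    enabled = device_type == "cuda"  # autocast/GradScaler become no-ops on CPU here

    net = torch.nn.Linear(8, 1).to(device_type)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=enabled)

    states = torch.randn(32, 8, device=device_type)
    targets = torch.randn(32, 1, device=device_type)

    # forward pass and loss in reduced precision
    with torch.autocast(device_type=device_type, enabled=enabled):
        loss = torch.nn.functional.mse_loss(net(states), targets)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)     # unscale so clipping sees true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
    scaler.step(optimizer)         # internally skipped if inf/NaN gradients are detected
    scaler.update()                # adjust the loss scale once per iteration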
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: nn.utils.clip_grad_norm_( itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip ) - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index e91ce293..4304896a 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -51,6 +51,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -158,6 +160,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: if self.policy is self.value: @@ -228,10 +236,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act( - {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="policy" - ) - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="policy" + ) + self._current_log_prob = log_prob return actions, log_prob, outputs @@ -280,10 +289,11 @@ def record_transition( rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - values, _, _ = self.value.act( - {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="value" - ) - values = self._value_preprocessor(values, inverse=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + values, _, _ = self.value.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha}, role="value" + ) + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: @@ -396,14 +406,14 @@ def compute_gae( return returns, advantages # compute returns and advantages - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): self.value.train(False) last_values, _, _ = self.value.act( {"states": self._state_preprocessor(self._current_next_states.float()), "alpha": self._alpha}, role="value", ) self.value.train(True) - last_values = self._value_preprocessor(last_values, inverse=True) + last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") returns, advantages = compute_gae( @@ -440,61 +450,71 @@ def compute_gae( sampled_advantages, ) in sampled_batches: - sampled_states = self._state_preprocessor(sampled_states, train=not epoch) + with 
torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - _, next_log_prob, _ = self.policy.act( - {"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha}, role="policy" - ) + sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - # compute approximate KL divergence - with torch.no_grad(): - ratio = next_log_prob - sampled_log_prob - kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() - kl_divergences.append(kl_divergence) + _, next_log_prob, _ = self.policy.act( + {"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha}, + role="policy", + ) - # early stopping with KL divergence - if self._kl_threshold and kl_divergence > self._kl_threshold: - break + # compute approximate KL divergence + with torch.no_grad(): + ratio = next_log_prob - sampled_log_prob + kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() + kl_divergences.append(kl_divergence) - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 + # early stopping with KL divergence + if self._kl_threshold and kl_divergence > self._kl_threshold: + break - # compute policy loss - ratio = torch.exp(next_log_prob - sampled_log_prob) - surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip( - ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip - ) + # compute entropy loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 - policy_loss = -torch.min(surrogate, surrogate_clipped).mean() + # compute policy loss + ratio = torch.exp(next_log_prob - sampled_log_prob) + surrogate = sampled_advantages * ratio + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) - # compute value loss - predicted_values, _, _ = self.value.act({"states": sampled_states, "alpha": self._alpha}, role="value") + policy_loss = -torch.min(surrogate, surrogate_clipped).mean() - if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip( - predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + # compute value loss + predicted_values, _, _ = self.value.act( + {"states": sampled_states, "alpha": self._alpha}, role="value" ) - value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) + + if self._clip_predicted_values: + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) + value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() if self.policy is not self.value: self.value.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.optimizer) if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: nn.utils.clip_grad_norm_( itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip ) - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git 
a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index 550534d4..d6c7b393 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -51,6 +51,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward "time_limit_bootstrap": False, # bootstrap at timeout termination (episode truncation) + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -158,6 +160,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizer and learning rate scheduler if self.policy is not None and self.value is not None: if self.policy is self.value: @@ -265,10 +273,11 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample stochastic actions - actions, log_prob, outputs = self.policy.act( - {"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="policy" - ) - self._current_log_prob = log_prob + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, log_prob, outputs = self.policy.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="policy" + ) + self._current_log_prob = log_prob if self._rnn: self._rnn_final_states["policy"] = outputs.get("rnn", []) @@ -320,11 +329,12 @@ def record_transition( rewards = self._rewards_shaper(rewards, timestep, timesteps) # compute values - rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} - values, _, outputs = self.value.act( - {"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="value" - ) - values = self._value_preprocessor(values, inverse=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} + values, _, outputs = self.value.act( + {"states": self._state_preprocessor(states), "alpha": self._alpha, **rnn}, role="value" + ) + values = self._value_preprocessor(values, inverse=True) # time-limit (truncation) bootstrapping if self._time_limit_bootstrap: @@ -467,7 +477,7 @@ def compute_gae( return returns, advantages # compute returns and advantages - with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): self.value.train(False) rnn = {"rnn": self._rnn_initial_states["value"]} if self._rnn else {} last_values, _, _ = self.value.act( @@ -475,7 +485,7 @@ def compute_gae( role="value", ) self.value.train(True) - last_values = self._value_preprocessor(last_values, inverse=True) + last_values = self._value_preprocessor(last_values, inverse=True) values = self.memory.get_tensor_by_name("values") returns, advantages = compute_gae( @@ -548,64 +558,76 @@ def compute_gae( "terminated": sampled_dones, } - sampled_states = self._state_preprocessor(sampled_states, train=not epoch) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - _, next_log_prob, _ = self.policy.act( - 
{"states": sampled_states, "taken_actions": sampled_actions, "alpha": self._alpha, **rnn_policy}, - role="policy", - ) + sampled_states = self._state_preprocessor(sampled_states, train=not epoch) - # compute approximate KL divergence - with torch.no_grad(): - ratio = next_log_prob - sampled_log_prob - kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() - kl_divergences.append(kl_divergence) + _, next_log_prob, _ = self.policy.act( + { + "states": sampled_states, + "taken_actions": sampled_actions, + "alpha": self._alpha, + **rnn_policy, + }, + role="policy", + ) - # early stopping with KL divergence - if self._kl_threshold and kl_divergence > self._kl_threshold: - break + # compute approximate KL divergence + with torch.no_grad(): + ratio = next_log_prob - sampled_log_prob + kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean() + kl_divergences.append(kl_divergence) - # compute entropy loss - if self._entropy_loss_scale: - entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() - else: - entropy_loss = 0 + # early stopping with KL divergence + if self._kl_threshold and kl_divergence > self._kl_threshold: + break - # compute policy loss - ratio = torch.exp(next_log_prob - sampled_log_prob) - surrogate = sampled_advantages * ratio - surrogate_clipped = sampled_advantages * torch.clip( - ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip - ) + # compute entropy loss + if self._entropy_loss_scale: + entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean() + else: + entropy_loss = 0 - policy_loss = -torch.min(surrogate, surrogate_clipped).mean() + # compute policy loss + ratio = torch.exp(next_log_prob - sampled_log_prob) + surrogate = sampled_advantages * ratio + surrogate_clipped = sampled_advantages * torch.clip( + ratio, 1.0 - self._ratio_clip, 1.0 + self._ratio_clip + ) - # compute value loss - predicted_values, _, _ = self.value.act( - {"states": sampled_states, "alpha": self._alpha, **rnn_value}, role="value" - ) + policy_loss = -torch.min(surrogate, surrogate_clipped).mean() - if self._clip_predicted_values: - predicted_values = sampled_values + torch.clip( - predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + # compute value loss + predicted_values, _, _ = self.value.act( + {"states": sampled_states, "alpha": self._alpha, **rnn_value}, role="value" ) - value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) + + if self._clip_predicted_values: + predicted_values = sampled_values + torch.clip( + predicted_values - sampled_values, min=-self._value_clip, max=self._value_clip + ) + value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values) # optimization step self.optimizer.zero_grad() - (policy_loss + entropy_loss + value_loss).backward() + self.scaler.scale(policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() if self.policy is not self.value: self.value.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.optimizer) if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) else: nn.utils.clip_grad_norm_( itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip ) - self.optimizer.step() + + self.scaler.step(self.optimizer) + self.scaler.update() # update cumulative losses cumulative_policy_loss += policy_loss.item() diff --git a/skrl/agents/torch/sac/sac.py 
b/skrl/agents/torch/sac/sac.py index df2a474c..fe7478bb 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -44,6 +44,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "base_directory": "", # base directory for the experiment "experiment_name": "", # experiment name @@ -160,6 +162,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # entropy if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] @@ -236,7 +244,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample stochastic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") return actions, None, outputs @@ -344,79 +353,99 @@ def _update(self, timestep: int, timesteps: int) -> None: names=self._tensors_names, batch_size=self._batch_size )[0] - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - - # compute target values - with torch.no_grad(): - next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") - - target_q1_values, _, _ = self.target_critic_1.act( - {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + # compute target values + with torch.no_grad(): + next_actions, next_log_prob, _ = self.policy.act({"states": sampled_next_states}, role="policy") + + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + ) + target_q_values = ( + torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob + ) + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + ) + + # compute critic loss + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1" ) - target_q2_values, _, _ = self.target_critic_2.act( - {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2" ) - target_q_values = ( - torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob - ) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values - # compute critic loss - 
critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2" - ) - - critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2 + critic_loss = ( + F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) + ) / 2 # optimization step (critic) self.critic_optimizer.zero_grad() - critic_loss.backward() + self.scaler.scale(critic_loss).backward() + if config.torch.is_distributed: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) nn.utils.clip_grad_norm_( itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip ) - self.critic_optimizer.step() - # compute policy (actor) loss - actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_2" - ) + self.scaler.step(self.critic_optimizer) + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, log_prob, _ = self.policy.act({"states": sampled_states}, role="policy") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_2" + ) - policy_loss = (self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values)).mean() + policy_loss = ( + self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values) + ).mean() # optimization step (policy) self.policy_optimizer.zero_grad() - policy_loss.backward() + self.scaler.scale(policy_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.policy_optimizer.step() + + self.scaler.step(self.policy_optimizer) # entropy learning if self._learn_entropy: - # compute entropy loss - entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute entropy loss + entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() # optimization step (entropy) self.entropy_optimizer.zero_grad() - entropy_loss.backward() - self.entropy_optimizer.step() + self.scaler.scale(entropy_loss).backward() + self.scaler.step(self.entropy_optimizer) # compute entropy coefficient self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) + self.scaler.update() # called once, after optimizers have been stepped + # update target networks self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak) self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index 3160a27e..6cb27c3d 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ 
b/skrl/agents/torch/sac/sac_rnn.py @@ -44,6 +44,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "base_directory": "", # base directory for the experiment "experiment_name": "", # experiment name @@ -160,6 +162,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # entropy if self._learn_entropy: self._target_entropy = self.cfg["target_entropy"] @@ -257,7 +265,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample stochastic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") if self._rnn: self._rnn_final_states["policy"] = outputs.get("rnn", []) @@ -394,81 +403,103 @@ def _update(self, timestep: int, timesteps: int) -> None: )[0] rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - - # compute target values - with torch.no_grad(): - next_actions, next_log_prob, _ = self.policy.act( - {"states": sampled_next_states, **rnn_policy}, role="policy" - ) - - target_q1_values, _, _ = self.target_critic_1.act( - {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_1" + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + # compute target values + with torch.no_grad(): + next_actions, next_log_prob, _ = self.policy.act( + {"states": sampled_next_states, **rnn_policy}, role="policy" + ) + + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, + role="target_critic_1", + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, + role="target_critic_2", + ) + target_q_values = ( + torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob + ) + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values + ) + + # compute critic loss + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1" ) - target_q2_values, _, _ = self.target_critic_2.act( - {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_2" + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2" ) - target_q_values = ( - torch.min(target_q1_values, target_q2_values) - self._entropy_coefficient * next_log_prob - 
) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values - # compute critic loss - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2" - ) - - critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2 + critic_loss = ( + F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) + ) / 2 # optimization step (critic) self.critic_optimizer.zero_grad() - critic_loss.backward() + self.scaler.scale(critic_loss).backward() + if config.torch.is_distributed: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) nn.utils.clip_grad_norm_( itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip ) - self.critic_optimizer.step() - # compute policy (actor) loss - actions, log_prob, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_2" - ) + self.scaler.step(self.critic_optimizer) + + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, log_prob, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1" + ) + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_2" + ) - policy_loss = (self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values)).mean() + policy_loss = ( + self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values) + ).mean() # optimization step (policy) self.policy_optimizer.zero_grad() - policy_loss.backward() + self.scaler.scale(policy_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.policy_optimizer.step() + + self.scaler.step(self.policy_optimizer) # entropy learning if self._learn_entropy: - # compute entropy loss - entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute entropy loss + entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean() # optimization step (entropy) self.entropy_optimizer.zero_grad() - entropy_loss.backward() - self.entropy_optimizer.step() + self.scaler.scale(entropy_loss).backward() + self.scaler.step(self.entropy_optimizer) # compute entropy coefficient self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach()) + self.scaler.update() # called once, after optimizers have been stepped + # update target networks self.target_critic_1.update_parameters(self.critic_1, 
polyak=self._polyak) self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index 7e5b7862..a178eb6b 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -49,6 +49,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -178,6 +180,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) @@ -240,7 +248,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy") # sample deterministic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy") # add exloration noise if self._exploration_noise is not None: @@ -380,79 +389,94 @@ def _update(self, timestep: int, timesteps: int) -> None: names=self._tensors_names, batch_size=self._batch_size )[0] - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) - - with torch.no_grad(): - # target policy smoothing - next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") - if self._smooth_regularization_noise is not None: - noises = torch.clamp( - self._smooth_regularization_noise.sample(next_actions.shape), - min=-self._smooth_regularization_clip, - max=self._smooth_regularization_clip, + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + with torch.no_grad(): + # target policy smoothing + next_actions, _, _ = self.target_policy.act({"states": sampled_next_states}, role="target_policy") + if self._smooth_regularization_noise is not None: + noises = torch.clamp( + self._smooth_regularization_noise.sample(next_actions.shape), + min=-self._smooth_regularization_clip, + max=self._smooth_regularization_clip, + ) + next_actions.add_(noises) + next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max) + + # compute target values + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + ) + target_q_values = torch.min(target_q1_values, target_q2_values) + target_values = ( + sampled_rewards + self._discount_factor * 
sampled_dones.logical_not() * target_q_values ) - next_actions.add_(noises) - next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max) - # compute target values - target_q1_values, _, _ = self.target_critic_1.act( - {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_1" + # compute critic loss + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1" ) - target_q2_values, _, _ = self.target_critic_2.act( - {"states": sampled_next_states, "taken_actions": next_actions}, role="target_critic_2" + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2" ) - target_q_values = torch.min(target_q1_values, target_q2_values) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values - # compute critic loss - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": sampled_actions}, role="critic_2" - ) - - critic_loss = F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) + critic_loss = F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) # optimization step (critic) self.critic_optimizer.zero_grad() - critic_loss.backward() + self.scaler.scale(critic_loss).backward() + if config.torch.is_distributed: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) nn.utils.clip_grad_norm_( itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip ) - self.critic_optimizer.step() + + self.scaler.step(self.critic_optimizer) # delayed update self._critic_update_counter += 1 if not self._critic_update_counter % self._policy_delay: - # compute policy (actor) loss - actions, _, _ = self.policy.act({"states": sampled_states}, role="policy") - critic_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": actions}, role="critic_1" - ) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, _, _ = self.policy.act({"states": sampled_states}, role="policy") + critic_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions}, role="critic_1" + ) - policy_loss = -critic_values.mean() + policy_loss = -critic_values.mean() # optimization step (policy) self.policy_optimizer.zero_grad() - policy_loss.backward() + self.scaler.scale(policy_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.policy_optimizer.step() + + self.scaler.step(self.policy_optimizer) # update target networks self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak) self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) self.target_policy.update_parameters(self.policy, polyak=self._polyak) + self.scaler.update() # called once, after optimizers have been stepped + # update learning rate if self._learning_rate_scheduler: self.policy_scheduler.step() diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index 
39d1aeed..36636172 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -49,6 +49,8 @@ "rewards_shaper": None, # rewards shaping function: Callable(reward, timestep, timesteps) -> reward + "mixed_precision": False, # enable automatic mixed precision for higher performance + "experiment": { "directory": "", # experiment's parent directory "experiment_name": "", # experiment name @@ -178,6 +180,12 @@ def __init__( self._rewards_shaper = self.cfg["rewards_shaper"] + self._mixed_precision = self.cfg["mixed_precision"] + + # set up automatic mixed precision + self._device_type = torch.device(device).type + self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision) + # set up optimizers and learning rate schedulers if self.policy is not None and self.critic_1 is not None and self.critic_2 is not None: self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._actor_learning_rate) @@ -261,7 +269,8 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens return self.policy.random_act({"states": self._state_preprocessor(states), **rnn}, role="policy") # sample deterministic actions - actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + actions, _, outputs = self.policy.act({"states": self._state_preprocessor(states), **rnn}, role="policy") if self._rnn: self._rnn_final_states["policy"] = outputs.get("rnn", []) @@ -430,81 +439,98 @@ def _update(self, timestep: int, timesteps: int) -> None: )[0] rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones} - sampled_states = self._state_preprocessor(sampled_states, train=True) - sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): - with torch.no_grad(): - # target policy smoothing - next_actions, _, _ = self.target_policy.act( - {"states": sampled_next_states, **rnn_policy}, role="target_policy" - ) - if self._smooth_regularization_noise is not None: - noises = torch.clamp( - self._smooth_regularization_noise.sample(next_actions.shape), - min=-self._smooth_regularization_clip, - max=self._smooth_regularization_clip, + sampled_states = self._state_preprocessor(sampled_states, train=True) + sampled_next_states = self._state_preprocessor(sampled_next_states, train=True) + + with torch.no_grad(): + # target policy smoothing + next_actions, _, _ = self.target_policy.act( + {"states": sampled_next_states, **rnn_policy}, role="target_policy" + ) + if self._smooth_regularization_noise is not None: + noises = torch.clamp( + self._smooth_regularization_noise.sample(next_actions.shape), + min=-self._smooth_regularization_clip, + max=self._smooth_regularization_clip, + ) + next_actions.add_(noises) + next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max) + + # compute target values + target_q1_values, _, _ = self.target_critic_1.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, + role="target_critic_1", + ) + target_q2_values, _, _ = self.target_critic_2.act( + {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, + role="target_critic_2", + ) + target_q_values = torch.min(target_q1_values, target_q2_values) + target_values = ( + sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values ) 
- next_actions.add_(noises) - next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max) - # compute target values - target_q1_values, _, _ = self.target_critic_1.act( - {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_1" + # compute critic loss + critic_1_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1" ) - target_q2_values, _, _ = self.target_critic_2.act( - {"states": sampled_next_states, "taken_actions": next_actions, **rnn_policy}, role="target_critic_2" + critic_2_values, _, _ = self.critic_2.act( + {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2" ) - target_q_values = torch.min(target_q1_values, target_q2_values) - target_values = sampled_rewards + self._discount_factor * sampled_dones.logical_not() * target_q_values - - # compute critic loss - critic_1_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_1" - ) - critic_2_values, _, _ = self.critic_2.act( - {"states": sampled_states, "taken_actions": sampled_actions, **rnn_policy}, role="critic_2" - ) - critic_loss = F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) + critic_loss = F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values) # optimization step (critic) self.critic_optimizer.zero_grad() - critic_loss.backward() + self.scaler.scale(critic_loss).backward() + if config.torch.is_distributed: self.critic_1.reduce_parameters() self.critic_2.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.critic_optimizer) nn.utils.clip_grad_norm_( itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip ) - self.critic_optimizer.step() + + self.scaler.step(self.critic_optimizer) # delayed update self._critic_update_counter += 1 if not self._critic_update_counter % self._policy_delay: - # compute policy (actor) loss - actions, _, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") - critic_values, _, _ = self.critic_1.act( - {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1" - ) + with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision): + # compute policy (actor) loss + actions, _, _ = self.policy.act({"states": sampled_states, **rnn_policy}, role="policy") + critic_values, _, _ = self.critic_1.act( + {"states": sampled_states, "taken_actions": actions, **rnn_policy}, role="critic_1" + ) - policy_loss = -critic_values.mean() + policy_loss = -critic_values.mean() # optimization step (policy) self.policy_optimizer.zero_grad() - policy_loss.backward() + self.scaler.scale(policy_loss).backward() + if config.torch.is_distributed: self.policy.reduce_parameters() + if self._grad_norm_clip > 0: + self.scaler.unscale_(self.policy_optimizer) nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) - self.policy_optimizer.step() + + self.scaler.step(self.policy_optimizer) # update target networks self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak) self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak) self.target_policy.update_parameters(self.policy, polyak=self._polyak) + self.scaler.update() # called once, after optimizers have been stepped + # update learning rate if self._learning_rate_scheduler: 
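The SAC and TD3 hunks above extend the same recipe to several optimizers: a single GradScaler scales and steps the critic, policy and (for SAC) entropy losses, and scaler.update() is deliberately called only once per iteration, after every optimizer has stepped, which is what the "called once, after optimizers have been stepped" comments mark. A condensed sketch of that multi-optimizer pattern, using an illustrative actor/critic pair rather than skrl models:

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    enabled = device_type == "cuda"

    critic = torch.nn.Linear(10, 1).to(device_type)  # consumes state (8) + action (2)
    actor = torch.nn.Linear(8, 2).to(device_type)
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=1e-3)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=enabled)

    states = torch.randn(32, 8, device=device_type)
    actions = torch.randn(32, 2, device=device_type)
    targets = torch.randn(32, 1, device=device_type)

    # critic update
    with torch.autocast(device_type=device_type, enabled=enabled):
        critic_loss = torch.nn.functional.mse_loss(
            critic(torch.cat([states, actions], dim=-1)), targets
        )
    critic_optimizer.zero_grad()
    scaler.scale(critic_loss).backward()
    scaler.step(critic_optimizer)  # scaler.unscale_(critic_optimizer) would precede any clipping

    # policy (actor) update
    with torch.autocast(device_type=device_type, enabled=enabled):
        actor_loss = -critic(torch.cat([states, actor(states)], dim=-1)).mean()
    actor_optimizer.zero_grad()
    scaler.scale(actor_loss).backward()
    scaler.step(actor_optimizer)

    scaler.update()  # exactly once, after all optimizers have stepped

In every agent this behaviour is gated by the new "mixed_precision" configuration entry; with its default value of False, the GradScaler and autocast regions are disabled and the update path behaves as before.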
self.policy_scheduler.step() diff --git a/skrl/utils/spaces/torch/spaces.py b/skrl/utils/spaces/torch/spaces.py index 1ee664be..fd8cabdb 100644 --- a/skrl/utils/spaces/torch/spaces.py +++ b/skrl/utils/spaces/torch/spaces.py @@ -106,7 +106,11 @@ def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool # Box if isinstance(space, spaces.Box): if isinstance(x, torch.Tensor): - array = np.array(x.cpu().numpy(), dtype=space.dtype) + # avoid TypeError: Got unsupported ScalarType BFloat16 + if x.dtype == torch.bfloat16: + array = np.array(x.to(dtype=torch.float32).cpu().numpy(), dtype=space.dtype) + else: + array = np.array(x.cpu().numpy(), dtype=space.dtype) if squeeze_batch_dimension and array.shape[0] == 1: return array.reshape(space.shape) return array.reshape(-1, *space.shape) @@ -114,7 +118,11 @@ def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool # Discrete elif isinstance(space, spaces.Discrete): if isinstance(x, torch.Tensor): - array = np.array(x.cpu().numpy(), dtype=space.dtype) + # avoid TypeError: Got unsupported ScalarType BFloat16 + if x.dtype == torch.bfloat16: + array = np.array(x.to(dtype=torch.float32).cpu().numpy(), dtype=space.dtype) + else: + array = np.array(x.cpu().numpy(), dtype=space.dtype) if squeeze_batch_dimension and array.shape[0] == 1: return array.item() return array.reshape(-1, 1) @@ -122,7 +130,11 @@ def untensorize_space(space: spaces.Space, x: Any, squeeze_batch_dimension: bool # MultiDiscrete elif isinstance(space, spaces.MultiDiscrete): if isinstance(x, torch.Tensor): - array = np.array(x.cpu().numpy(), dtype=space.dtype) + # avoid TypeError: Got unsupported ScalarType BFloat16 + if x.dtype == torch.bfloat16: + array = np.array(x.to(dtype=torch.float32).cpu().numpy(), dtype=space.dtype) + else: + array = np.array(x.cpu().numpy(), dtype=space.dtype) if squeeze_batch_dimension and array.shape[0] == 1: return array.reshape(space.nvec.shape) return array.reshape(-1, *space.nvec.shape) diff --git a/tests/jax/test_jax_agent_a2c.py b/tests/jax/test_jax_agent_a2c.py new file mode 100644 index 00000000..c90787fd --- /dev/null +++ b/tests/jax/test_jax_agent_a2c.py @@ -0,0 +1,185 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.a2c import A2C as Agent +from skrl.agents.jax.a2c import A2C_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), 
st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "CategoricalMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + mini_batches, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + entropy_loss_scale, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + if policy_structure in ["GaussianMixin"]: + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + elif policy_structure == "CategoricalMixin": + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "CategoricalMixin": + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "entropy_loss_scale": entropy_loss_scale, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "experiment": { + "directory": "", + "experiment_name": "", + 
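Returning for a moment to the skrl/utils/spaces/torch/spaces.py change above: NumPy has no bfloat16 dtype, so tensors produced under bfloat16 autocast must be upcast before conversion, which is exactly what the added guard does. A minimal standalone reproduction (not skrl code):

    import numpy as np
    import torch

    x = torch.ones(4, 3, dtype=torch.bfloat16)
    # x.cpu().numpy() raises "TypeError: Got unsupported ScalarType BFloat16",
    # hence the upcast to float32 before converting to the space's NumPy dtype
    array = np.array(x.to(dtype=torch.float32).cpu().numpy(), dtype=np.float32)
    assert array.shape == (4, 3)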
"write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_cem.py b/tests/jax/test_jax_agent_cem.py new file mode 100644 index 00000000..9ae7d2e9 --- /dev/null +++ b/tests/jax/test_jax_agent_cem.py @@ -0,0 +1,142 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.cem import CEM as Agent +from skrl.agents.jax.cem import CEM_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import categorical_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + percentile=st.floats(min_value=0, max_value=1), + discount_factor=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + rollouts, + percentile, + discount_factor, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = categorical_model( + observation_space=env.observation_space, + 
action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "percentile": percentile, + "discount_factor": discount_factor, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_ddpg.py b/tests/jax/test_jax_agent_ddpg.py new file mode 100644 index 00000000..631ac3f0 --- /dev/null +++ b/tests/jax/test_jax_agent_ddpg.py @@ -0,0 +1,199 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.ddpg import DDPG as Agent +from skrl.agents.jax.ddpg import DDPG_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.noises.jax import GaussianNoise, OrnsteinUhlenbeckNoise +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), 
st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + exploration=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + exploration_initial_scale=st.floats(min_value=0, max_value=1), + exploration_final_scale=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + exploration, + exploration_initial_scale, + exploration_final_scale, + exploration_timesteps, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "exploration": { + "initial_scale": exploration_initial_scale, + "final_scale": exploration_final_scale, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + 
"checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + # noise + # - exploration + if exploration is None: + cfg["exploration"]["noise"] = None + elif exploration is OrnsteinUhlenbeckNoise: + cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=1.0, device=env.device) + elif exploration is GaussianNoise: + cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_ddqn.py b/tests/jax/test_jax_agent_ddqn.py new file mode 100644 index 00000000..0b5fddb2 --- /dev/null +++ b/tests/jax/test_jax_agent_ddqn.py @@ -0,0 +1,174 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.dqn import DDQN as Agent +from skrl.agents.jax.dqn import DDQN_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.integers(min_value=0, max_value=5), + update_interval=st.integers(min_value=1, max_value=3), + target_update_interval=st.integers(min_value=1, max_value=5), + exploration_initial_epsilon=st.floats(min_value=0, max_value=1), + exploration_final_epsilon=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), 
st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + update_interval, + target_update_interval, + exploration_initial_epsilon, + exploration_final_epsilon, + exploration_timesteps, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "update_interval": update_interval, + "target_update_interval": target_update_interval, + "exploration": { + "initial_epsilon": exploration_initial_epsilon, + "final_epsilon": exploration_final_epsilon, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_dqn.py b/tests/jax/test_jax_agent_dqn.py new file mode 100644 index 00000000..bfc93fda --- /dev/null +++ b/tests/jax/test_jax_agent_dqn.py @@ -0,0 
+1,174 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.dqn import DQN as Agent +from skrl.agents.jax.dqn import DQN_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.integers(min_value=0, max_value=5), + update_interval=st.integers(min_value=1, max_value=3), + target_update_interval=st.integers(min_value=1, max_value=5), + exploration_initial_epsilon=st.floats(min_value=0, max_value=1), + exploration_final_epsilon=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + update_interval, + target_update_interval, + exploration_initial_epsilon, + exploration_final_epsilon, + exploration_timesteps, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + 
device=env.device, + network=network, + output="ACTIONS", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "update_interval": update_interval, + "target_update_interval": target_update_interval, + "exploration": { + "initial_epsilon": exploration_initial_epsilon, + "final_epsilon": exploration_final_epsilon, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_ppo.py b/tests/jax/test_jax_agent_ppo.py new file mode 100644 index 00000000..27bafaff --- /dev/null +++ b/tests/jax/test_jax_agent_ppo.py @@ -0,0 +1,207 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.ppo import PPO as Agent +from skrl.agents.jax.ppo import PPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + 
discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + kl_threshold=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "CategoricalMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + kl_threshold, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + if policy_structure in ["GaussianMixin"]: + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + elif policy_structure == "CategoricalMixin": + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "CategoricalMixin": + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, 
+ "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "kl_threshold": kl_threshold, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_rpo.py b/tests/jax/test_jax_agent_rpo.py new file mode 100644 index 00000000..6fec0f01 --- /dev/null +++ b/tests/jax/test_jax_agent_rpo.py @@ -0,0 +1,199 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.rpo import RPO as Agent +from skrl.agents.jax.rpo import RPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + alpha=st.floats(min_value=0, max_value=1), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + 
state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + kl_threshold=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + alpha, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + kl_threshold, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "alpha": alpha, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "kl_threshold": kl_threshold, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + 
"experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_sac.py b/tests/jax/test_jax_agent_sac.py new file mode 100644 index 00000000..112257a4 --- /dev/null +++ b/tests/jax/test_jax_agent_sac.py @@ -0,0 +1,195 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.sac import SAC as Agent +from skrl.agents.jax.sac import SAC_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + learn_entropy=st.booleans(), + entropy_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + initial_entropy_value=st.floats(min_value=0, max_value=1), + target_entropy=st.one_of(st.none(), st.floats(min_value=-1, max_value=1)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", 
"cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + learn_entropy, + entropy_learning_rate, + initial_entropy_value, + target_entropy, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "learn_entropy": learn_entropy, + "entropy_learning_rate": entropy_learning_rate, + "initial_entropy_value": initial_entropy_value, + "target_entropy": target_entropy, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, 
agents=agent) + + trainer.train() diff --git a/tests/jax/test_jax_agent_td3.py b/tests/jax/test_jax_agent_td3.py new file mode 100644 index 00000000..b1ec36f1 --- /dev/null +++ b/tests/jax/test_jax_agent_td3.py @@ -0,0 +1,230 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import optax + +from skrl.agents.jax.td3 import TD3 as Agent +from skrl.agents.jax.td3 import TD3_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.jax import wrap_env +from skrl.memories.jax import RandomMemory +from skrl.resources.noises.jax import GaussianNoise, OrnsteinUhlenbeckNoise +from skrl.resources.preprocessors.jax import RunningStandardScaler +from skrl.resources.schedulers.jax import KLAdaptiveLR +from skrl.trainers.jax import SequentialTrainer +from skrl.utils.model_instantiators.jax import deterministic_model +from skrl.utils.spaces.jax import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(optax.schedules.constant_schedule)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + exploration=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + exploration_initial_scale=st.floats(min_value=0, max_value=1), + exploration_final_scale=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + policy_delay=st.integers(min_value=1, max_value=3), + smooth_regularization_noise=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + smooth_regularization_clip=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + exploration, + exploration_initial_scale, + exploration_final_scale, + exploration_timesteps, + policy_delay, + smooth_regularization_noise, + 
smooth_regularization_clip, + rewards_shaper, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + # instantiate models' state dict + for role, model in models.items(): + model.init_state_dict(role) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "exploration": { + "initial_scale": exploration_initial_scale, + "final_scale": exploration_final_scale, + "timesteps": exploration_timesteps, + }, + "policy_delay": policy_delay, + "smooth_regularization_clip": smooth_regularization_clip, + "rewards_shaper": rewards_shaper, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "value" + ] = learning_rate_scheduler_kwargs_value + # noise + # - exploration + if exploration is None: + cfg["exploration"]["noise"] = None + elif exploration is OrnsteinUhlenbeckNoise: + cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=1.0, device=env.device) + elif exploration is GaussianNoise: + cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + # - regularization + if smooth_regularization_noise is None: + cfg["smooth_regularization_noise"] = None + elif smooth_regularization_noise is OrnsteinUhlenbeckNoise: + cfg["smooth_regularization_noise"] = OrnsteinUhlenbeckNoise( + 
theta=0.1, sigma=0.2, base_scale=1.0, device=env.device + ) + elif smooth_regularization_noise is GaussianNoise: + cfg["smooth_regularization_noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_a2c.py b/tests/torch/test_torch_agent_a2c.py new file mode 100644 index 00000000..e04e34c2 --- /dev/null +++ b/tests/torch/test_torch_agent_a2c.py @@ -0,0 +1,216 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.a2c import A2C as Agent +from skrl.agents.torch.a2c import A2C_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import ( + categorical_model, + deterministic_model, + gaussian_model, + multivariate_gaussian_model, + shared_model, +) +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True, False]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "MultivariateGaussianMixin", "CategoricalMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + mini_batches, + 
discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + entropy_loss_scale, + rewards_shaper, + time_limit_bootstrap, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + if policy_structure in ["GaussianMixin", "MultivariateGaussianMixin"]: + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + elif policy_structure == "CategoricalMixin": + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "MultivariateGaussianMixin": + models["policy"] = multivariate_gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "CategoricalMixin": + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + models["policy"] = shared_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + structure=[policy_structure, "DeterministicMixin"], + parameters=[ + { + "network": network, + "output": "ACTIONS", + }, + { + "network": network, + "output": "ONE", + }, + ], + roles=["policy", "value"], + ) + models["value"] = models["policy"] + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "entropy_loss_scale": entropy_loss_scale, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + 
observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_amp.py b/tests/torch/test_torch_agent_amp.py new file mode 100644 index 00000000..e1326010 --- /dev/null +++ b/tests/torch/test_torch_agent_amp.py @@ -0,0 +1,248 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.amp import AMP as Agent +from skrl.agents.torch.amp import AMP_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def __init__(self, observation_space, action_space, num_envs, device, amp_observation_space): + super().__init__(observation_space, action_space, num_envs, device) + self.amp_observation_space = amp_observation_space + + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + def step(self, actions): + observations, rewards, terminated, truncated, info = super().step(actions) + info["terminate"] = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) + info["amp_obs"] = sample_space(self.amp_observation_space, self.num_envs, backend="torch", device=self.device) + return observations, rewards, terminated, truncated, info + + def fetch_amp_obs_demo(self, num_samples): + return sample_space(self.amp_observation_space, num_samples, backend="torch", device=self.device) + + def reset_done(self): + return ({"obs": sample_space(self.observation_space, self.num_envs, backend="torch", device=self.device)},) + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + amp_state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + 
value_loss_scale=st.floats(min_value=0, max_value=1), + discriminator_loss_scale=st.floats(min_value=0, max_value=1), + amp_batch_size=st.integers(min_value=1, max_value=5), + task_reward_weight=st.floats(min_value=0, max_value=1), + style_reward_weight=st.floats(min_value=0, max_value=1), + discriminator_batch_size=st.integers(min_value=0, max_value=5), + discriminator_reward_scale=st.floats(min_value=0, max_value=1), + discriminator_logit_regularization_scale=st.floats(min_value=0, max_value=1), + discriminator_gradient_penalty_scale=st.floats(min_value=0, max_value=1), + discriminator_weight_decay_scale=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + amp_state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + discriminator_loss_scale, + amp_batch_size, + task_reward_weight, + style_reward_weight, + discriminator_batch_size, + discriminator_reward_scale, + discriminator_logit_regularization_scale, + discriminator_gradient_penalty_scale, + discriminator_weight_decay_scale, + rewards_shaper, + time_limit_bootstrap, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + amp_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(10,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device, amp_observation_space), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["discriminator"] = deterministic_model( + observation_space=env.amp_observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + raise NotImplementedError + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device},
"value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "amp_state_preprocessor": amp_state_preprocessor, + "amp_state_preprocessor_kwargs": {"size": env.amp_observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "discriminator_loss_scale": discriminator_loss_scale, + "amp_batch_size": amp_batch_size, + "task_reward_weight": task_reward_weight, + "style_reward_weight": style_reward_weight, + "discriminator_batch_size": discriminator_batch_size, + "discriminator_reward_scale": discriminator_reward_scale, + "discriminator_logit_regularization_scale": discriminator_logit_regularization_scale, + "discriminator_gradient_penalty_scale": discriminator_gradient_penalty_scale, + "discriminator_weight_decay_scale": discriminator_weight_decay_scale, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + amp_observation_space=env.amp_observation_space, + motion_dataset=RandomMemory(memory_size=50, device=device), + reply_buffer=RandomMemory(memory_size=100, device=device), + collect_reference_motions=lambda num_samples: env.fetch_amp_obs_demo(num_samples), + collect_observation=lambda: env.reset_done()[0]["obs"], + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_cem.py b/tests/torch/test_torch_agent_cem.py new file mode 100644 index 00000000..06804965 --- /dev/null +++ b/tests/torch/test_torch_agent_cem.py @@ -0,0 +1,142 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.cem import CEM as Agent +from skrl.agents.torch.cem import CEM_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import categorical_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + for k in default_config.keys(): + assert 
k in config + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + percentile=st.floats(min_value=0, max_value=1), + discount_factor=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + rollouts, + percentile, + discount_factor, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + rewards_shaper, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "percentile": percentile, + "discount_factor": discount_factor, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "rewards_shaper": rewards_shaper, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_ddpg.py b/tests/torch/test_torch_agent_ddpg.py new file mode 100644 index 00000000..b50811cc --- /dev/null +++ b/tests/torch/test_torch_agent_ddpg.py @@ -0,0 +1,199 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + 
+import gymnasium + +import torch + +from skrl.agents.torch.ddpg import DDPG as Agent +from skrl.agents.torch.ddpg import DDPG_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.noises.torch import GaussianNoise, OrnsteinUhlenbeckNoise +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import deterministic_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + exploration=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + exploration_initial_scale=st.floats(min_value=0, max_value=1), + exploration_final_scale=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + exploration, + exploration_initial_scale, + exploration_final_scale, + exploration_timesteps, + rewards_shaper, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = deterministic_model( + observation_space=env.observation_space, + 
action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "exploration": { + "initial_scale": exploration_initial_scale, + "final_scale": exploration_final_scale, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + # noise + # - exploration + if exploration is None: + cfg["exploration"]["noise"] = None + elif exploration is OrnsteinUhlenbeckNoise: + cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=1.0, device=env.device) + elif exploration is GaussianNoise: + cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_ddqn.py b/tests/torch/test_torch_agent_ddqn.py new file mode 100644 index 00000000..e914294b --- /dev/null +++ b/tests/torch/test_torch_agent_ddqn.py @@ -0,0 +1,174 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.dqn import DDQN as Agent +from skrl.agents.torch.dqn import DDQN_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer 
+from skrl.utils.model_instantiators.torch import deterministic_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + update_interval=st.integers(min_value=1, max_value=3), + target_update_interval=st.integers(min_value=1, max_value=5), + exploration_initial_epsilon=st.floats(min_value=0, max_value=1), + exploration_final_epsilon=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + update_interval, + target_update_interval, + exploration_initial_epsilon, + exploration_final_epsilon, + exploration_timesteps, + rewards_shaper, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + 
"learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "update_interval": update_interval, + "target_update_interval": target_update_interval, + "exploration": { + "initial_epsilon": exploration_initial_epsilon, + "final_epsilon": exploration_final_epsilon, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_dqn.py b/tests/torch/test_torch_agent_dqn.py new file mode 100644 index 00000000..3dd9dd2c --- /dev/null +++ b/tests/torch/test_torch_agent_dqn.py @@ -0,0 +1,174 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.dqn import DQN as Agent +from skrl.agents.torch.dqn import DQN_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import deterministic_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + 
learning_starts=st.integers(min_value=0, max_value=5), + update_interval=st.integers(min_value=1, max_value=3), + target_update_interval=st.integers(min_value=1, max_value=5), + exploration_initial_epsilon=st.floats(min_value=0, max_value=1), + exploration_final_epsilon=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + update_interval, + target_update_interval, + exploration_initial_epsilon, + exploration_final_epsilon, + exploration_timesteps, + rewards_shaper, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_q_network"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "update_interval": update_interval, + "target_update_interval": target_update_interval, + "exploration": { + "initial_epsilon": exploration_initial_epsilon, + "final_epsilon": exploration_final_epsilon, + "timesteps": exploration_timesteps, + }, + "rewards_shaper": rewards_shaper, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + 
"disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_ppo.py b/tests/torch/test_torch_agent_ppo.py new file mode 100644 index 00000000..30c0ab8d --- /dev/null +++ b/tests/torch/test_torch_agent_ppo.py @@ -0,0 +1,238 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.ppo import PPO as Agent +from skrl.agents.torch.ppo import PPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import ( + categorical_model, + deterministic_model, + gaussian_model, + multivariate_gaussian_model, + shared_model, +) +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + kl_threshold=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True, False]) +@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "MultivariateGaussianMixin", "CategoricalMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + 
state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + kl_threshold, + rewards_shaper, + time_limit_bootstrap, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + if policy_structure in ["GaussianMixin", "MultivariateGaussianMixin"]: + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + elif policy_structure == "CategoricalMixin": + action_space = gymnasium.spaces.Discrete(3) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "MultivariateGaussianMixin": + models["policy"] = multivariate_gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "CategoricalMixin": + models["policy"] = categorical_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + models["policy"] = shared_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + structure=[policy_structure, "DeterministicMixin"], + parameters=[ + { + "network": network, + "output": "ACTIONS", + }, + { + "network": network, + "output": "ONE", + }, + ], + roles=["policy", "value"], + ) + models["value"] = models["policy"] + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "kl_threshold": kl_threshold, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + 
_check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_rpo.py b/tests/torch/test_torch_agent_rpo.py new file mode 100644 index 00000000..6388dce9 --- /dev/null +++ b/tests/torch/test_torch_agent_rpo.py @@ -0,0 +1,229 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.rpo import RPO as Agent +from skrl.agents.torch.rpo import RPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import ( + deterministic_model, + gaussian_model, + multivariate_gaussian_model, + shared_model, +) +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + alpha=st.floats(min_value=0, max_value=1), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + ratio_clip=st.floats(min_value=0, max_value=1), + value_clip=st.floats(min_value=0, max_value=1), + clip_predicted_values=st.booleans(), + entropy_loss_scale=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + kl_threshold=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("separate", [True, False]) 
+@pytest.mark.parametrize("policy_structure", ["GaussianMixin", "MultivariateGaussianMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + separate, + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + alpha, + discount_factor, + lambda_, + learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + ratio_clip, + value_clip, + clip_predicted_values, + entropy_loss_scale, + value_loss_scale, + kl_threshold, + rewards_shaper, + time_limit_bootstrap, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if separate: + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "MultivariateGaussianMixin": + models["policy"] = multivariate_gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + else: + models["policy"] = shared_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + structure=[policy_structure, "DeterministicMixin"], + parameters=[ + { + "network": network, + "output": "ACTIONS", + }, + { + "network": network, + "output": "ONE", + }, + ], + roles=["policy", "value"], + ) + models["value"] = models["policy"] + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "alpha": alpha, + "discount_factor": discount_factor, + "lambda": lambda_, + "learning_rate": learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "ratio_clip": ratio_clip, + "value_clip": value_clip, + "clip_predicted_values": clip_predicted_values, + "entropy_loss_scale": entropy_loss_scale, + "value_loss_scale": value_loss_scale, + "kl_threshold": kl_threshold, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = 
learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_sac.py b/tests/torch/test_torch_agent_sac.py new file mode 100644 index 00000000..0060e215 --- /dev/null +++ b/tests/torch/test_torch_agent_sac.py @@ -0,0 +1,195 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.sac import SAC as Agent +from skrl.agents.torch.sac import SAC_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + learn_entropy=st.booleans(), + entropy_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + initial_entropy_value=st.floats(min_value=0, max_value=1), + target_entropy=st.one_of(st.none(), st.floats(min_value=-1, max_value=1)), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + 
learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + learn_entropy, + entropy_learning_rate, + initial_entropy_value, + target_entropy, + rewards_shaper, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "learn_entropy": learn_entropy, + "entropy_learning_rate": entropy_learning_rate, + "initial_entropy_value": initial_entropy_value, + "target_entropy": target_entropy, + "rewards_shaper": rewards_shaper, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_td3.py b/tests/torch/test_torch_agent_td3.py new file mode 100644 index 00000000..919945af --- /dev/null +++ b/tests/torch/test_torch_agent_td3.py @@ -0,0 +1,230 @@ +import hypothesis +import 
hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.td3 import TD3 as Agent +from skrl.agents.torch.td3 import TD3_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.noises.torch import GaussianNoise, OrnsteinUhlenbeckNoise +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import deterministic_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + gradient_steps=st.integers(min_value=1, max_value=2), + batch_size=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + polyak=st.floats(min_value=0, max_value=1), + actor_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + critic_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.integers(min_value=0, max_value=5), + learning_starts=st.integers(min_value=0, max_value=5), + grad_norm_clip=st.floats(min_value=0, max_value=1), + exploration=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + exploration_initial_scale=st.floats(min_value=0, max_value=1), + exploration_final_scale=st.floats(min_value=0, max_value=1), + exploration_timesteps=st.one_of(st.none(), st.integers(min_value=1, max_value=50)), + policy_delay=st.integers(min_value=1, max_value=3), + smooth_regularization_noise=st.one_of(st.none(), st.just(OrnsteinUhlenbeckNoise), st.just(GaussianNoise)), + smooth_regularization_clip=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + mixed_precision=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_agent( + capsys, + device, + num_envs, + # agent config + gradient_steps, + batch_size, + discount_factor, + polyak, + actor_learning_rate, + critic_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + exploration, + exploration_initial_scale, + exploration_final_scale, + exploration_timesteps, + policy_delay, + smooth_regularization_noise, + smooth_regularization_clip, + rewards_shaper, + mixed_precision, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, 
shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + models["policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["target_policy"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_1"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + models["target_critic_2"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + + # memory + memory = RandomMemory(memory_size=50, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "gradient_steps": gradient_steps, + "batch_size": batch_size, + "discount_factor": discount_factor, + "polyak": polyak, + "actor_learning_rate": actor_learning_rate, + "critic_learning_rate": critic_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "exploration": { + "initial_scale": exploration_initial_scale, + "final_scale": exploration_final_scale, + "timesteps": exploration_timesteps, + }, + "policy_delay": policy_delay, + "smooth_regularization_clip": smooth_regularization_clip, + "rewards_shaper": rewards_shaper, + "mixed_precision": mixed_precision, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + # noise + # - exploration + if exploration is None: + cfg["exploration"]["noise"] = None + elif exploration is OrnsteinUhlenbeckNoise: + cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=1.0, device=env.device) + elif exploration is GaussianNoise: + cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + # - regularization + if smooth_regularization_noise is None: + cfg["smooth_regularization_noise"] = None + elif smooth_regularization_noise is OrnsteinUhlenbeckNoise: + cfg["smooth_regularization_noise"] = OrnsteinUhlenbeckNoise( + theta=0.1, sigma=0.2, base_scale=1.0, device=env.device + ) + elif smooth_regularization_noise is GaussianNoise: + cfg["smooth_regularization_noise"] = GaussianNoise(mean=0, std=0.1, device=env.device) + _check_agent_config(cfg, DEFAULT_CONFIG) + 
_check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + _check_agent_config(cfg["exploration"], DEFAULT_CONFIG["exploration"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": 50, + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) + + trainer.train() diff --git a/tests/torch/test_torch_agent_trpo.py b/tests/torch/test_torch_agent_trpo.py new file mode 100644 index 00000000..7f6d4cfd --- /dev/null +++ b/tests/torch/test_torch_agent_trpo.py @@ -0,0 +1,199 @@ +import hypothesis +import hypothesis.strategies as st +import pytest + +import gymnasium + +import torch + +from skrl.agents.torch.trpo import TRPO as Agent +from skrl.agents.torch.trpo import TRPO_DEFAULT_CONFIG as DEFAULT_CONFIG +from skrl.envs.wrappers.torch import wrap_env +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.resources.schedulers.torch import KLAdaptiveLR +from skrl.trainers.torch import SequentialTrainer +from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, multivariate_gaussian_model +from skrl.utils.spaces.torch import sample_space + +from ..utils import BaseEnv + + +class Env(BaseEnv): + def _sample_observation(self): + return sample_space(self.observation_space, self.num_envs, backend="numpy") + + +def _check_agent_config(config, default_config): + for k in config.keys(): + assert k in default_config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + for k in default_config.keys(): + assert k in config + if k == "experiment": + _check_agent_config(config["experiment"], default_config["experiment"]) + + +@hypothesis.given( + num_envs=st.integers(min_value=1, max_value=5), + rollouts=st.integers(min_value=1, max_value=5), + learning_epochs=st.integers(min_value=1, max_value=5), + mini_batches=st.integers(min_value=1, max_value=5), + discount_factor=st.floats(min_value=0, max_value=1), + lambda_=st.floats(min_value=0, max_value=1), + value_learning_rate=st.floats(min_value=1.0e-10, max_value=1), + learning_rate_scheduler=st.one_of(st.none(), st.just(KLAdaptiveLR), st.just(torch.optim.lr_scheduler.ConstantLR)), + learning_rate_scheduler_kwargs_value=st.floats(min_value=0.1, max_value=1), + state_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + value_preprocessor=st.one_of(st.none(), st.just(RunningStandardScaler)), + random_timesteps=st.just(0), + learning_starts=st.just(0), + grad_norm_clip=st.floats(min_value=0, max_value=1), + value_loss_scale=st.floats(min_value=0, max_value=1), + damping=st.floats(min_value=0, max_value=1), + max_kl_divergence=st.floats(min_value=0, max_value=1), + conjugate_gradient_steps=st.integers(min_value=1, max_value=5), + max_backtrack_steps=st.integers(min_value=1, max_value=5), + accept_ratio=st.floats(min_value=0, max_value=1), + step_fraction=st.floats(min_value=0, max_value=1), + rewards_shaper=st.one_of(st.none(), st.just(lambda rewards, *args, **kwargs: 0.5 * rewards)), + time_limit_bootstrap=st.booleans(), +) +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("policy_structure", 
["GaussianMixin", "MultivariateGaussianMixin"]) +def test_agent( + capsys, + device, + num_envs, + # model config + policy_structure, + # agent config + rollouts, + learning_epochs, + mini_batches, + discount_factor, + lambda_, + value_learning_rate, + learning_rate_scheduler, + learning_rate_scheduler_kwargs_value, + state_preprocessor, + value_preprocessor, + random_timesteps, + learning_starts, + grad_norm_clip, + value_loss_scale, + damping, + max_kl_divergence, + conjugate_gradient_steps, + max_backtrack_steps, + accept_ratio, + step_fraction, + rewards_shaper, + time_limit_bootstrap, +): + # spaces + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(5,)) + action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(3,)) + + # env + env = wrap_env(Env(observation_space, action_space, num_envs, device), wrapper="gymnasium") + + # models + network = [ + { + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "elu", + } + ] + models = {} + if policy_structure == "GaussianMixin": + models["policy"] = gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + elif policy_structure == "MultivariateGaussianMixin": + models["policy"] = multivariate_gaussian_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ACTIONS", + ) + models["value"] = deterministic_model( + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + network=network, + output="ONE", + ) + + # memory + memory = RandomMemory(memory_size=rollouts, num_envs=env.num_envs, device=env.device) + + # agent + cfg = { + "rollouts": rollouts, + "learning_epochs": learning_epochs, + "mini_batches": mini_batches, + "discount_factor": discount_factor, + "lambda": lambda_, + "value_learning_rate": value_learning_rate, + "learning_rate_scheduler": learning_rate_scheduler, + "learning_rate_scheduler_kwargs": {}, + "state_preprocessor": state_preprocessor, + "state_preprocessor_kwargs": {"size": env.observation_space, "device": env.device}, + "value_preprocessor": value_preprocessor, + "value_preprocessor_kwargs": {"size": 1, "device": env.device}, + "random_timesteps": random_timesteps, + "learning_starts": learning_starts, + "grad_norm_clip": grad_norm_clip, + "value_loss_scale": value_loss_scale, + "damping": damping, + "max_kl_divergence": max_kl_divergence, + "conjugate_gradient_steps": conjugate_gradient_steps, + "max_backtrack_steps": max_backtrack_steps, + "accept_ratio": accept_ratio, + "step_fraction": step_fraction, + "rewards_shaper": rewards_shaper, + "time_limit_bootstrap": time_limit_bootstrap, + "experiment": { + "directory": "", + "experiment_name": "", + "write_interval": 0, + "checkpoint_interval": 0, + "store_separately": False, + "wandb": False, + "wandb_kwargs": {}, + }, + } + cfg["learning_rate_scheduler_kwargs"][ + "kl_threshold" if learning_rate_scheduler is KLAdaptiveLR else "factor" + ] = learning_rate_scheduler_kwargs_value + _check_agent_config(cfg, DEFAULT_CONFIG) + _check_agent_config(cfg["experiment"], DEFAULT_CONFIG["experiment"]) + agent = Agent( + models=models, + memory=memory, + cfg=cfg, + observation_space=env.observation_space, + action_space=env.action_space, + device=env.device, + ) + + # trainer + cfg_trainer = { + "timesteps": int(5 * rollouts), + "headless": True, + "disable_progressbar": True, + "close_environment_at_exit": False, + } + trainer = 
SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)
+
+    trainer.train()
diff --git a/tests/utils.py b/tests/utils.py
index 9c3b6ef2..890f6089 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,100 +1,36 @@
 import random
 
-import gymnasium as gym
+import gymnasium
 
-import torch
+import numpy as np
 
 
-class DummyEnv(gym.Env):
-    def __init__(self, num_envs, device="cpu"):
-        self.num_agents = 1
+class BaseEnv(gymnasium.Env):
+    def __init__(self, observation_space, action_space, num_envs, device):
+        self.device = device
         self.num_envs = num_envs
-        self.device = torch.device(device)
-        self.action_space = gym.spaces.Discrete(2)
-        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
+        self.action_space = action_space
+        self.observation_space = observation_space
 
-    def __getattr__(self, key):
-        if key in ["_spec_to_space", "observation_spec"]:
-            return lambda *args, **kwargs: None
-        return None
+    def _sample_observation(self):
+        raise NotImplementedError
 
-    def step(self, action):
-        observation = self.observation_space.sample()
-        reward = random.random()
-        terminated = random.random() > 0.95
-        truncated = random.random() > 0.95
+    def step(self, actions):
+        if self.num_envs == 1:
+            rewards = random.random()
+            terminated = random.random() > 0.95
+            truncated = random.random() > 0.95
+        else:
+            rewards = np.random.random((self.num_envs,))
+            terminated = np.random.random((self.num_envs,)) > 0.95
+            truncated = np.random.random((self.num_envs,)) > 0.95
 
-        observation = torch.tensor(observation, dtype=torch.float32).view(self.num_envs, -1)
-        reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
-        terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
-        truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
-
-        return observation, reward, terminated, truncated, {}
+        return self._sample_observation(), rewards, terminated, truncated, {}
 
     def reset(self):
-        observation = self.observation_space.sample()
-        observation = torch.tensor(observation, dtype=torch.float32).view(self.num_envs, -1)
-        return observation, {}
+        return self._sample_observation(), {}
 
     def render(self, *args, **kwargs):
         pass
 
     def close(self, *args, **kwargs):
         pass
-
-
-class _DummyBaseAgent:
-    def __init__(self):
-        pass
-
-    def record_transition(
-        self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps
-    ):
-        pass
-
-    def pre_interaction(self, timestep, timesteps):
-        pass
-
-    def post_interaction(self, timestep, timesteps):
-        pass
-
-    def set_running_mode(self, mode):
-        pass
-
-
-class DummyAgent(_DummyBaseAgent):
-    def __init__(self):
-        super().__init__()
-
-    def init(self, trainer_cfg=None):
-        pass
-
-    def act(self, states, timestep, timesteps):
-        return torch.tensor([]), None, {}
-
-    def record_transition(
-        self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps
-    ):
-        pass
-
-    def pre_interaction(self, timestep, timesteps):
-        pass
-
-    def post_interaction(self, timestep, timesteps):
-        pass
-
-
-class DummyModel(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        self.device = torch.device("cpu")
-        self.layer = torch.nn.Linear(1, 1)
-
-    def set_mode(self, *args, **kwargs):
-        pass
-
-    def get_specification(self, *args, **kwargs):
-        return {}
-
-    def act(self, *args, **kwargs):
-        return torch.tensor([]), None, {}
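
For context on what the `mixed_precision` flag exercised throughout these tests switches on, the usual PyTorch pattern is an autocast context around the forward/loss computation plus a gradient scaler around the backward and optimizer step. The sketch below is illustrative only, with placeholder names (`model`, `optimizer`, `loss_fn`, `states`, `targets`); it is not the agents' actual update code.

import torch

mixed_precision = True  # what cfg["mixed_precision"] toggles in the tests above
device_type = "cuda" if torch.cuda.is_available() else "cpu"

# GradScaler and autocast are pass-throughs when enabled=False, so a single
# code path serves both full- and mixed-precision training.
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)

def update_step(model, optimizer, loss_fn, states, targets):
    optimizer.zero_grad()
    # run the forward pass and loss in reduced precision where safe
    with torch.autocast(device_type=device_type, enabled=mixed_precision):
        loss = loss_fn(model(states), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adjust the scale factor for the next iteration

The new tests themselves can be run per agent, e.g. `pytest tests/torch/test_torch_agent_ppo.py`; the `cuda:0` parametrization naturally requires a CUDA-capable device.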