Automatic mixed precision training in PyTorch (#243)
* Update PPO mixed-precision implementation in torch

* Add A2C mixed precision support in torch

* Add AMP mixed precision support in torch

* Add CEM mixed precision support in torch

* Add DDPG mixed precision support in torch

* Add DQN and DDQN mixed precision support in torch

* Add RPO mixed precision support in torch

* Add SAC mixed precision support in torch

* Add TD3 mixed precision support in torch

* Update docs

* Add PPO test in torch

* Add agent tests in torch

* Add agent tests in jax

* Avoid TypeError: Got unsupported ScalarType BFloat16

* Update CHANGELOG
Toni-SM authored Jan 1, 2025
1 parent 23b61dc commit 2882db4
Showing 50 changed files with 4,917 additions and 648 deletions.
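For context, the automatic mixed precision pattern this commit applies across the agents combines torch.autocast around forward passes with a gradient scaler around the backward and optimizer steps. Below is a minimal, generic sketch of that pattern (not skrl's exact code; model, optimizer, loss_fn and grad_norm_clip are placeholder names):

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
scaler = torch.cuda.amp.GradScaler(enabled=(device_type == "cuda"))

def train_step(model, optimizer, loss_fn, inputs, targets, grad_norm_clip=0.5):
    # run the forward pass in reduced precision where it is safe to do so
    with torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        loss = loss_fn(model(inputs), targets)

    optimizer.zero_grad()
    # scale the loss so small float16 gradients do not underflow to zero
    scaler.scale(loss).backward()
    # unscale before clipping so the threshold applies to the true gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
    # step() skips the optimizer update if gradients contain inf/NaN; update() adjusts the scale
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

The A2C diff at the end of this page shows how the same steps are threaded through an actual agent's update loop.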
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - `parse_device` static method in ML framework configuration (used in library components to set up the device)
 - Model instantiator support for different shared model structures in PyTorch
 - Support for other model types than Gaussian and Deterministic in runners
+- Support for automatic mixed precision training in PyTorch

 ### Changed
 - Call agent's `pre_interaction` method during evaluation
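Usage note (not part of the diff): the feature is opt-in per agent through the new mixed_precision configuration key. A minimal sketch of enabling it for A2C, assuming skrl's usual agent setup; models, memory and env are placeholders defined elsewhere:

from skrl.agents.torch.a2c import A2C, A2C_DEFAULT_CONFIG

cfg = A2C_DEFAULT_CONFIG.copy()
cfg["mixed_precision"] = True  # opt in to automatic mixed precision (default: False)

agent = A2C(
    models=models,  # placeholder: dict with "policy" and "value" models defined elsewhere
    memory=memory,  # placeholder: a skrl memory instance defined elsewhere
    cfg=cfg,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=env.device,
)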
4 changes: 4 additions & 0 deletions docs/source/api/agents/a2c.rst
@@ -232,6 +232,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/amp.rst
@@ -237,6 +237,10 @@ Support for advanced features is described in the next table
       - \-
       - .. centered:: :math:`\square`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/cem.rst
@@ -175,6 +175,10 @@ Support for advanced features is described in the next table
       - \-
       - .. centered:: :math:`\square`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - \-
       - .. centered:: :math:`\square`
4 changes: 4 additions & 0 deletions docs/source/api/agents/ddpg.rst
@@ -236,6 +236,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/ddqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
       - \-
       - .. centered:: :math:`\square`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/dqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
       - \-
       - .. centered:: :math:`\square`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/ppo.rst
@@ -248,6 +248,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/rpo.rst
@@ -285,6 +285,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/sac.rst
@@ -244,6 +244,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/td3.rst
@@ -258,6 +258,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - Automatic mixed precision
+      - .. centered:: :math:`\blacksquare`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
4 changes: 4 additions & 0 deletions docs/source/api/agents/trpo.rst
@@ -282,6 +282,10 @@ Support for advanced features is described in the next table
       - RNN, LSTM, GRU and any other variant
       - .. centered:: :math:`\blacksquare`
       - .. centered:: :math:`\square`
+    * - Mixed precision
+      - \-
+      - .. centered:: :math:`\square`
+      - .. centered:: :math:`\square`
     * - Distributed
       - Single Program Multi Data (SPMD) multi-GPU
       - .. centered:: :math:`\blacksquare`
75 changes: 46 additions & 29 deletions skrl/agents/torch/a2c/a2c.py
@@ -43,6 +43,8 @@
     "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
     "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

+    "mixed_precision": False,       # enable automatic mixed precision for higher performance
+
     "experiment": {
         "directory": "",            # experiment's parent directory
         "experiment_name": "",      # experiment name
@@ -142,6 +144,12 @@ def __init__(
         self._rewards_shaper = self.cfg["rewards_shaper"]
         self._time_limit_bootstrap = self.cfg["time_limit_bootstrap"]

+        self._mixed_precision = self.cfg["mixed_precision"]
+
+        # set up automatic mixed precision
+        self._device_type = torch.device(device).type
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self._mixed_precision)
+
         # set up optimizer and learning rate scheduler
         if self.policy is not None and self.value is not None:
             if self.policy is self.value:
@@ -211,8 +219,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor:
             return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy")

         # sample stochastic actions
-        actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy")
-        self._current_log_prob = log_prob
+        with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
+            actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy")
+            self._current_log_prob = log_prob

         return actions, log_prob, outputs

@@ -261,8 +270,9 @@ def record_transition(
             rewards = self._rewards_shaper(rewards, timestep, timesteps)

         # compute values
-        values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value")
-        values = self._value_preprocessor(values, inverse=True)
+        with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
+            values, _, _ = self.value.act({"states": self._state_preprocessor(states)}, role="value")
+            values = self._value_preprocessor(values, inverse=True)

         # time-limit (truncation) bootstrapping
         if self._time_limit_bootstrap:
@@ -375,13 +385,13 @@ def compute_gae(
             return returns, advantages

         # compute returns and advantages
-        with torch.no_grad():
+        with torch.no_grad(), torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
             self.value.train(False)
             last_values, _, _ = self.value.act(
                 {"states": self._state_preprocessor(self._current_next_states.float())}, role="value"
             )
             self.value.train(True)
-        last_values = self._value_preprocessor(last_values, inverse=True)
+            last_values = self._value_preprocessor(last_values, inverse=True)

         values = self.memory.get_tensor_by_name("values")
         returns, advantages = compute_gae(
@@ -409,49 +419,56 @@ def compute_gae(
         # mini-batches loop
         for sampled_states, sampled_actions, sampled_log_prob, sampled_returns, sampled_advantages in sampled_batches:

-            sampled_states = self._state_preprocessor(sampled_states, train=True)
+            with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):

-            _, next_log_prob, _ = self.policy.act(
-                {"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
-            )
+                sampled_states = self._state_preprocessor(sampled_states, train=True)

-            # compute approximate KL divergence for KLAdaptive learning rate scheduler
-            if self._learning_rate_scheduler:
-                if isinstance(self.scheduler, KLAdaptiveLR):
-                    with torch.no_grad():
-                        ratio = next_log_prob - sampled_log_prob
-                        kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
-                        kl_divergences.append(kl_divergence)
+                _, next_log_prob, _ = self.policy.act(
+                    {"states": sampled_states, "taken_actions": sampled_actions}, role="policy"
+                )

-            # compute entropy loss
-            if self._entropy_loss_scale:
-                entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
-            else:
-                entropy_loss = 0
+                # compute approximate KL divergence for KLAdaptive learning rate scheduler
+                if self._learning_rate_scheduler:
+                    if isinstance(self.scheduler, KLAdaptiveLR):
+                        with torch.no_grad():
+                            ratio = next_log_prob - sampled_log_prob
+                            kl_divergence = ((torch.exp(ratio) - 1) - ratio).mean()
+                            kl_divergences.append(kl_divergence)

-            # compute policy loss
-            policy_loss = -(sampled_advantages * next_log_prob).mean()
+                # compute entropy loss
+                if self._entropy_loss_scale:
+                    entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy(role="policy").mean()
+                else:
+                    entropy_loss = 0

-            # compute value loss
-            predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value")
+                # compute policy loss
+                policy_loss = -(sampled_advantages * next_log_prob).mean()

-            value_loss = F.mse_loss(sampled_returns, predicted_values)
+                # compute value loss
+                predicted_values, _, _ = self.value.act({"states": sampled_states}, role="value")
+
+                value_loss = F.mse_loss(sampled_returns, predicted_values)

             # optimization step
             self.optimizer.zero_grad()
-            (policy_loss + entropy_loss + value_loss).backward()
+            self.scaler.scale(policy_loss + entropy_loss + value_loss).backward()

             if config.torch.is_distributed:
                 self.policy.reduce_parameters()
                 if self.policy is not self.value:
                     self.value.reduce_parameters()

             if self._grad_norm_clip > 0:
+                self.scaler.unscale_(self.optimizer)
                 if self.policy is self.value:
                     nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
                 else:
                     nn.utils.clip_grad_norm_(
                         itertools.chain(self.policy.parameters(), self.value.parameters()), self._grad_norm_clip
                     )
-            self.optimizer.step()
+
+            self.scaler.step(self.optimizer)
+            self.scaler.update()

             # update cumulative losses
             cumulative_policy_loss += policy_loss.item()
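The commit message also mentions avoiding TypeError: Got unsupported ScalarType BFloat16; that fix lives in one of the other changed files, not in the diff shown above. NumPy has no bfloat16 dtype, so a bfloat16 tensor (as can be produced under autocast, for example on CPU) must be upcast before conversion. The following sketch shows the common workaround and is an assumption about the nature of the fix, not the commit's exact code:

import torch

t = torch.tensor([0.25, 1.5], dtype=torch.bfloat16)
# t.numpy() raises: TypeError: Got unsupported ScalarType BFloat16
array = t.detach().float().cpu().numpy()  # upcast to float32 before converting to NumPy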