Skip to content

Commit

Permalink
Merge pull request #147 from alessandroassirelli98/fix_paper_mismatch…
Browse files Browse the repository at this point in the history
…_develop

Fix DDPG and TD3 agents: move sampling inside gradient step loop to fix implementation mismatch
  • Loading branch information
Toni-SM authored Apr 17, 2024
2 parents 631613a + 94f18fa commit a3fb36c
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [1.2.0] - Unreleased
### Fixed
- Moved the batch sampling inside gradient step loop for DDPG and TD3. [SpinningUp description.](https://spinningup.openai.com/en/latest/algorithms/ddpg.html#pseudocode)

## [1.1.0] - 2024-02-12
### Added
- MultiCategorical mixin to operate MultiDiscrete action spaces
Expand Down
4 changes: 2 additions & 2 deletions docs/source/api/agents/ddpg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`a' \leftarrow \mu_{\theta_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow Q_{\phi_{target}}(s', a')`
Expand Down
4 changes: 2 additions & 2 deletions docs/source/api/agents/td3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# target policy smoothing`
| :math:`a' \leftarrow \mu_{\theta_{target}}(s')`
| :math:`noise \leftarrow \text{clip}(` :guilabel:`smooth_regularization_noise` :math:`, -c, c) \qquad` with :math:`c` as :guilabel:`smooth_regularization_clip`
Expand Down
7 changes: 4 additions & 3 deletions skrl/agents/jax/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

Expand Down
6 changes: 3 additions & 3 deletions skrl/agents/jax/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,12 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
Expand Down
7 changes: 4 additions & 3 deletions skrl/agents/torch/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

Expand Down
7 changes: 4 additions & 3 deletions skrl/agents/torch/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

Expand Down

0 comments on commit a3fb36c

Please sign in to comment.