Fix IPPO and MAPPO act method return values when JAX-NumPy backend is enabled (#193)
Toni-SM authored Aug 29, 2024
1 parent 238641e commit 53b0243
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - Catch TensorBoard summary iterator exceptions in `TensorboardFileIterator` postprocessing utils
 - Fix automatic wrapper detection for Isaac Gym (previews), DeepMind and vectorized Gymnasium environments
 - Fix vectorized/parallel environments `reset` method return values when called more than once
+- IPPO and MAPPO `act` method return values when JAX-NumPy backend is enabled

 ## [1.2.0] - 2024-06-23
 ### Added
4 changes: 2 additions & 2 deletions skrl/multi_agents/jax/ippo/ippo.py
@@ -375,8 +375,8 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int,
         outputs = {uid: d[2] for uid, d in zip(self.possible_agents, data)}

         if not self._jax:  # numpy backend
-            actions = {jax.device_get(_actions) for _actions in actions}
-            log_prob = {jax.device_get(_log_prob) for _log_prob in log_prob}
+            actions = {uid: jax.device_get(_actions) for uid, _actions in actions.items()}
+            log_prob = {uid: jax.device_get(_log_prob) for uid, _log_prob in log_prob.items()}

         self._current_log_prob = log_prob

4 changes: 2 additions & 2 deletions skrl/multi_agents/jax/mappo/mappo.py
@@ -391,8 +391,8 @@ def act(self, states: Mapping[str, Union[np.ndarray, jax.Array]], timestep: int,
         outputs = {uid: d[2] for uid, d in zip(self.possible_agents, data)}

         if not self._jax:  # numpy backend
-            actions = {jax.device_get(_actions) for _actions in actions}
-            log_prob = {jax.device_get(_log_prob) for _log_prob in log_prob}
+            actions = {uid: jax.device_get(_actions) for uid, _actions in actions.items()}
+            log_prob = {uid: jax.device_get(_log_prob) for uid, _log_prob in log_prob.items()}

         self._current_log_prob = log_prob
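
Note on the change: the two comprehensions in each `act` method differ only in whether they build a set or a dict. The sketch below is a minimal illustration of that difference, not the library's code. It assumes a hypothetical `to_host` function standing in for `jax.device_get`, and uses made-up agent names and plain lists in place of arrays so that it runs without JAX installed.

```python
# Hypothetical stand-in for jax.device_get; it returns its argument unchanged.
def to_host(value):
    return value


# Per-agent actions keyed by agent uid (names and values are illustrative).
actions = {"agent_0": [0.1, 0.2], "agent_1": [0.3, 0.4]}

# Old form: iterating a dict yields its keys, and braces without "key: value"
# build a set, so the result is a set of uids and the action values are lost.
old_result = {to_host(_actions) for _actions in actions}

# New form: a dict comprehension over .items() keeps the uid -> value mapping.
new_result = {uid: to_host(_actions) for uid, _actions in actions.items()}

print(type(old_result).__name__, sorted(old_result))
print(type(new_result).__name__, new_result)
```

With the old form, any caller that indexes the returned value by agent uid would fail, since a set does not support key lookup. The new form returns the same keys as the input mapping.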
