-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for MultiDiscrete and MultiBinary action spaces in PPO (#30)
* Added support for MultiDiscrete action space to PPO
* Added support for MultiBinary action spaces as discrete action spaces with two choices
* Added tests for PPO with MultiDiscrete and MultiBinary action spaces
* Moved the padding comment
* Fixed type errors
* Replaced | by Union in type hint to support python < 3.10
* Update ruff
* Rename variables
* Add more comments and pre-compute variables
* Check that actions are not outside action space
* [ci skip] Update version

---------

Co-authored-by: Antonin Raffin <[email protected]>
- Loading branch information
Showing
7 changed files
with
135 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.11.0 | ||
0.12.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from dataclasses import dataclass | ||
from typing import Dict, Optional | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
import pytest | ||
|
||
from sbx import PPO | ||
|
||
# Two-element float32 Box bounded to [-1, 1]; passed as the observation space
# of every dummy environment built in this file.
BOX_SPACE_FLOAT32 = gym.spaces.Box(
    low=-1,
    high=1,
    shape=(2,),
    dtype=np.float32,
)
|
||
|
||
@dataclass
class DummyEnv(gym.Env):
    """Minimal environment whose spaces are supplied by the caller.

    Observations are random samples from ``observation_space``; the
    environment holds no state beyond its two spaces.
    """

    # Space that observations are sampled from.
    observation_space: gym.spaces.Space
    # Space that every action passed to ``step`` must belong to.
    action_space: gym.spaces.Space

    def step(self, action):
        """Check that ``action`` lies in the action space and return a random transition.

        The returned tuple is a sampled observation, a reward of ``0.0``,
        ``False`` for both the terminated and truncated flags, and an empty
        info dict.
        """
        # Fail loudly if the agent produced an action outside the declared space.
        assert action in self.action_space
        next_obs = self.observation_space.sample()
        info = {}
        return next_obs, 0.0, False, False, info

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        """Return a sampled observation and an empty info dict.

        The seed is forwarded to ``gym.Env.reset`` only when one is given;
        ``options`` is accepted but not used.
        """
        if seed is not None:
            super().reset(seed=seed)
        first_obs = self.observation_space.sample()
        info = {}
        return first_obs, info
|
||
|
||
class DummyMultiDiscreteAction(DummyEnv):
    """Dummy environment with a ``MultiDiscrete([2, 3])`` action space."""

    def __init__(self):
        multi_discrete = gym.spaces.MultiDiscrete([2, 3])
        super().__init__(
            observation_space=BOX_SPACE_FLOAT32,
            action_space=multi_discrete,
        )
|
||
|
||
class DummyMultiBinaryAction(DummyEnv):
    """Dummy environment with a ``MultiBinary(2)`` action space."""

    def __init__(self):
        multi_binary = gym.spaces.MultiBinary(2)
        super().__init__(
            observation_space=BOX_SPACE_FLOAT32,
            action_space=multi_binary,
        )
|
||
|
||
@pytest.mark.parametrize(
    "env",
    [
        "Pendulum-v1",
        "CartPole-v1",
        DummyMultiDiscreteAction(),
        DummyMultiBinaryAction(),
    ],
)
def test_ppo_action_spaces(env):
    """Smoke test: PPO can be built and trained briefly on each parametrized env."""
    # Small rollout and batch sizes keep the run short; 64 steps gives two rollouts of 32.
    agent = PPO("MlpPolicy", env, n_steps=32, batch_size=16)
    agent.learn(total_timesteps=64)
|
||
|
||
def test_ppo_multidim_discrete_not_supported():
    """Constructing PPO with a nested (multi-dimensional) MultiDiscrete space must fail."""
    expected_message = r"Only one-dimensional MultiDiscrete action spaces are supported, but found MultiDiscrete\(.*\)."
    # The nested list makes the space multi-dimensional.
    nested_action_space = gym.spaces.MultiDiscrete([[2, 3]])
    env = DummyEnv(BOX_SPACE_FLOAT32, nested_action_space)
    with pytest.raises(AssertionError, match=expected_message):
        PPO("MlpPolicy", env)
|
||
|
||
def test_ppo_multidim_binary_not_supported():
    """Constructing PPO with a multi-dimensional MultiBinary space must fail."""
    expected_message = r"Multi-dimensional MultiBinary\(.*\) action space is not supported"
    # A list argument gives the space more than one dimension.
    multidim_action_space = gym.spaces.MultiBinary([2, 3])
    env = DummyEnv(BOX_SPACE_FLOAT32, multidim_action_space)
    with pytest.raises(AssertionError, match=expected_message):
        PPO("MlpPolicy", env)