Support for MultiDiscrete and MultiBinary action spaces in PPO (#30)
* Added support for MultiDiscrete action space to PPO (the -inf padding trick is sketched right after the changed-files summary below)

* Added support for MultiBinary action spaces as discrete action spaces with two choices

* Added tests for PPO with MultiDiscrete and MultiBinary action spaces

* Moved the padding comment

* Fixed type errors

* Replaced | by Union in type hint to support python < 3.10

* Update ruff

* Rename variables

* Add more comments and pre-compute variables

* Check that actions are not outside action space

* [ci skip] Update version

---------

Co-authored-by: Antonin Raffin <[email protected]>
jan1854 and araffin authored Feb 28, 2024
1 parent e564074 commit db6120b
Showing 7 changed files with 135 additions and 17 deletions.
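The core of the change is how the factored (MultiDiscrete) action distribution is built: per-dimension logits are padded to a common width with -inf so a single batched Categorical can represent all dimensions, and Independent sums the per-dimension log-probabilities. A minimal standalone sketch of that trick, written here for illustration (it assumes tensorflow-probability's JAX substrate, which sbx uses):

```python
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

nvec = np.array([2, 3])  # hypothetical MultiDiscrete([2, 3])
flat_logits = jnp.zeros((1, int(nvec.sum())))  # one batch entry, 2 + 3 = 5 logits
# Split the flat logits into one chunk per action dimension.
chunks = jnp.split(flat_logits, np.cumsum(nvec[:-1]), axis=1)
# Pad each chunk to max(nvec) columns with -inf => probability 0 for invalid choices.
padded = jnp.stack(
    [jnp.pad(c, ((0, 0), (0, int(nvec.max()) - c.shape[1])), constant_values=-np.inf) for c in chunks],
    axis=1,
)  # shape (1, 2, 3)
dist = tfd.Independent(tfd.Categorical(logits=padded), reinterpreted_batch_ndims=1)
print(dist.log_prob(jnp.array([[1, 2]])))  # log(1/2) + log(1/3) ≈ -1.79
```

Padding with -inf gives the invalid (padded) choices probability exactly zero, so they can never be sampled and contribute nothing to the log-probability.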
Makefile (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ type: mypy
lint:
	# stop the build if there are Python syntax errors or undefined names
	# see https://www.flake8rules.com/
-	ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
+	ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
	# exit-zero treats all errors as warnings.
	ruff ${LINT_PATHS} --exit-zero

README.md (6 changes: 3 additions & 3 deletions)
@@ -38,17 +38,17 @@ import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ

env = gym.make("Pendulum-v1")
env = gym.make("Pendulum-v1", render_mode="human")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
-for i in range(1000):
+for _ in range(1000):
+    vec_env.render()
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
-    vec_env.render()

vec_env.close()
```
pyproject.toml (4 changes: 3 additions & 1 deletion)
@@ -3,13 +3,15 @@
line-length = 127
# Assume Python 3.8
target-version = "py38"

+[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
# Ignore explicit stacklevel`
ignore = ["B028"]


-[tool.ruff.mccabe]
+[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15

sbx/ppo/policies.py (72 changes: 63 additions & 9 deletions)
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from dataclasses import field
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import flax.linen as nn
import gymnasium as gym
@@ -37,26 +38,61 @@ class Actor(nn.Module):
    action_dim: int
    n_units: int = 256
    log_std_init: float = 0.0
-    continuous: bool = True
    activation_fn: Callable = nn.tanh
+    # For Discrete, MultiDiscrete and MultiBinary actions
+    num_discrete_choices: Optional[Union[int, Sequence[int]]] = None
+    # For MultiDiscrete
+    max_num_choices: int = 0
+    split_indices: np.ndarray = field(default_factory=lambda: np.array([]))

-    def get_std(self):
+    def get_std(self) -> jnp.ndarray:
        # Make it work with gSDE
        return jnp.array(0.0)

+    def __post_init__(self) -> None:
+        # For MultiDiscrete
+        if isinstance(self.num_discrete_choices, np.ndarray):
+            self.max_num_choices = max(self.num_discrete_choices)
+            # np.cumsum(...) gives the correct indices at which to split the flattened logits
+            self.split_indices = np.cumsum(self.num_discrete_choices[:-1])
+        super().__post_init__()
+
    @nn.compact
    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
        x = Flatten()(x)
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
-        mean = nn.Dense(self.action_dim)(x)
-        if self.continuous:
+        action_logits = nn.Dense(self.action_dim)(x)
+        if self.num_discrete_choices is None:
+            # Continuous actions
            log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,))
-            dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std))
+            dist = tfd.MultivariateNormalDiag(loc=action_logits, scale_diag=jnp.exp(log_std))
+        elif isinstance(self.num_discrete_choices, int):
+            dist = tfd.Categorical(logits=action_logits)
        else:
-            dist = tfd.Categorical(logits=mean)
+            # Split action_logits = (batch_size, total_choices=sum(self.num_discrete_choices))
+            action_logits = jnp.split(action_logits, self.split_indices, axis=1)
+            # Pad to the maximum number of choices (required by tfp.distributions.Categorical).
+            # Pad by -inf, so that the probability of these invalid actions is 0.
+            logits_padded = jnp.stack(
+                [
+                    jnp.pad(
+                        logit,
+                        # logit is of shape (batch_size, n)
+                        # only pad after dim=1, to max_num_choices - n
+                        # pad_width=((before_dim_0, after_0), (before_dim_1, after_1))
+                        pad_width=((0, 0), (0, self.max_num_choices - logit.shape[1])),
+                        constant_values=-np.inf,
+                    )
+                    for logit in action_logits
+                ],
+                axis=1,
+            )
+            dist = tfp.distributions.Independent(
+                tfp.distributions.Categorical(logits=logits_padded), reinterpreted_batch_ndims=1
+            )
        return dist
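As a quick illustration of what `__post_init__` precomputes, with values worked out by hand for a hypothetical MultiDiscrete([2, 3, 4]):

```python
import numpy as np

num_discrete_choices = np.array([2, 3, 4])  # MultiDiscrete([2, 3, 4])
max_num_choices = max(num_discrete_choices)            # 4
split_indices = np.cumsum(num_discrete_choices[:-1])   # array([2, 5])
# jnp.split(flat_logits, split_indices, axis=1) then cuts the 9 flat logits
# into chunks of widths 2, 3 and 4, one per action dimension.
```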


@@ -123,12 +159,30 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
        if isinstance(self.action_space, spaces.Box):
            actor_kwargs = {
                "action_dim": int(np.prod(self.action_space.shape)),
-                "continuous": True,
            }
        elif isinstance(self.action_space, spaces.Discrete):
            actor_kwargs = {
                "action_dim": int(self.action_space.n),
-                "continuous": False,
+                "num_discrete_choices": int(self.action_space.n),
            }
+        elif isinstance(self.action_space, spaces.MultiDiscrete):
+            assert self.action_space.nvec.ndim == 1, (
+                f"Only one-dimensional MultiDiscrete action spaces are supported, "
+                f"but found MultiDiscrete({(self.action_space.nvec).tolist()})."
+            )
+            actor_kwargs = {
+                "action_dim": int(np.sum(self.action_space.nvec)),
+                "num_discrete_choices": self.action_space.nvec,  # type: ignore[dict-item]
+            }
+        elif isinstance(self.action_space, spaces.MultiBinary):
+            assert isinstance(self.action_space.n, int), (
+                f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. "
+                "You can flatten it instead."
+            )
+            # Handle binary action spaces as discrete action spaces with two choices.
+            actor_kwargs = {
+                "action_dim": 2 * self.action_space.n,
+                "num_discrete_choices": 2 * np.ones(self.action_space.n, dtype=int),
+            }
        else:
            raise NotImplementedError(f"{self.action_space}")
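For reference, the `actor_kwargs` each branch above produces, computed by hand for some hypothetical example spaces:

```python
import numpy as np
from gymnasium import spaces

# Box(shape=(2,))       -> {"action_dim": 2}
# Discrete(4)           -> {"action_dim": 4, "num_discrete_choices": 4}
# MultiDiscrete([2, 3]) -> {"action_dim": 5, "num_discrete_choices": array([2, 3])}
# MultiBinary(3)        -> {"action_dim": 6, "num_discrete_choices": array([2, 2, 2])}
space = spaces.MultiBinary(3)
print(2 * space.n)                      # 6
print(2 * np.ones(space.n, dtype=int))  # [2 2 2]
```

The MultiBinary case is thus just a MultiDiscrete with two choices per dimension, which is why no new distribution code is needed for it.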
sbx/ppo/ppo.py (4 changes: 2 additions & 2 deletions)
@@ -123,8 +123,8 @@ def __init__(
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
-                # spaces.MultiDiscrete,
-                # spaces.MultiBinary,
+                spaces.MultiDiscrete,
+                spaces.MultiBinary,
            ),
        )
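With the two entries uncommented, PPO no longer rejects these spaces at construction time. A minimal end-to-end sketch (the toy env here is hypothetical, mirroring the dummy envs in the new tests below):

```python
import gymnasium as gym
import numpy as np

from sbx import PPO


class ToyMultiDiscreteEnv(gym.Env):
    """Hypothetical env: 2D box observations, MultiDiscrete([2, 3]) actions."""

    observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
    action_space = gym.spaces.MultiDiscrete([2, 3])

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        assert action in self.action_space
        return self.observation_space.sample(), 0.0, False, False, {}


model = PPO("MlpPolicy", ToyMultiDiscreteEnv(), n_steps=32, batch_size=16)
model.learn(64)
```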

sbx/version.txt (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-0.11.0
+0.12.0
tests/test_spaces.py (62 changes: 62 additions & 0 deletions)
@@ -0,0 +1,62 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import gymnasium as gym
+import numpy as np
+import pytest
+
+from sbx import PPO
+
+BOX_SPACE_FLOAT32 = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+
+
+@dataclass
+class DummyEnv(gym.Env):
+    observation_space: gym.spaces.Space
+    action_space: gym.spaces.Space
+
+    def step(self, action):
+        assert action in self.action_space
+        return self.observation_space.sample(), 0.0, False, False, {}
+
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+        if seed is not None:
+            super().reset(seed=seed)
+        return self.observation_space.sample(), {}
+
+
+class DummyMultiDiscreteAction(DummyEnv):
+    def __init__(self):
+        super().__init__(
+            BOX_SPACE_FLOAT32,
+            gym.spaces.MultiDiscrete([2, 3]),
+        )
+
+
+class DummyMultiBinaryAction(DummyEnv):
+    def __init__(self):
+        super().__init__(
+            BOX_SPACE_FLOAT32,
+            gym.spaces.MultiBinary(2),
+        )
+
+
+@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultiDiscreteAction(), DummyMultiBinaryAction()])
+def test_ppo_action_spaces(env):
+    model = PPO("MlpPolicy", env, n_steps=32, batch_size=16)
+    model.learn(64)
+
+
+def test_ppo_multidim_discrete_not_supported():
+    env = DummyEnv(BOX_SPACE_FLOAT32, gym.spaces.MultiDiscrete([[2, 3]]))
+    with pytest.raises(
+        AssertionError,
+        match=r"Only one-dimensional MultiDiscrete action spaces are supported, but found MultiDiscrete\(.*\).",
+    ):
+        PPO("MlpPolicy", env)
+
+
+def test_ppo_multidim_binary_not_supported():
+    env = DummyEnv(BOX_SPACE_FLOAT32, gym.spaces.MultiBinary([2, 3]))
+    with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
+        PPO("MlpPolicy", env)
