Skip to content

Commit

Permalink
Add maskable GraphPPO based on sb3_contrib.MaskablePPO
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed Dec 5, 2024
1 parent 2d45178 commit 9550a07
Show file tree
Hide file tree
Showing 12 changed files with 409 additions and 65 deletions.
28 changes: 24 additions & 4 deletions examples/gnn_sb3_jsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from skdecide.hub.domain.gym import GymDomain
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO
from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.utils import rollout

Expand Down Expand Up @@ -40,6 +41,9 @@ def _state_step(
outcome.state = self._np_state2graph_state(outcome.state)
return outcome

def action_masks(self):
    """Return the mask of currently applicable actions.

    Delegates to the wrapped JSP gym environment's ``valid_action_mask``;
    used by maskable solvers (e.g. MaskableGraphPPO) to restrict sampling
    to applicable actions.
    """
    return self._gym_env.valid_action_mask()
def _get_applicable_actions_from(
self, memory: D.T_memory[D.T_state]
) -> D.T_agent[Space[D.T_event]]:
Expand Down Expand Up @@ -120,21 +124,37 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
action_mode="task",
)


# Random rollout: sanity-check the domain wrapper before training.
domain = GraphJspDomain(gym_env=jsp_env)
rollout(domain=domain, max_steps=jsp_env.total_tasks_without_dummies, num_episodes=1)

# Solve with sb3-GraphPPO (PPO on graph observations, no action masking).
domain_factory = lambda: GraphJspDomain(gym_env=jsp_env)
with StableBaseline(
    domain_factory=domain_factory,
    algo_class=GraphPPO,
    baselines_policy="GraphInputPolicy",
    # Tiny budget: this is a smoke-test example, not a real training run.
    learn_config={"total_timesteps": 100},
) as solver:

    solver.solve()
    rollout(domain=domain_factory(), solver=solver, max_steps=100, num_episodes=1)

# Solve with sb3-MaskableGraphPPO: same as above but with action masking,
# relying on GraphJspDomain.action_masks() during both training and rollout.
domain_factory = lambda: GraphJspDomain(gym_env=jsp_env)
with StableBaseline(
    domain_factory=domain_factory,
    algo_class=MaskableGraphPPO,
    baselines_policy="GraphInputPolicy",
    # Tiny budget: this is a smoke-test example, not a real training run.
    learn_config={"total_timesteps": 100},
    use_action_masking=True,
) as solver:

    solver.solve()
    rollout(
        domain=domain_factory(),
        solver=solver,
        max_steps=100,
        num_episodes=1,
        use_action_masking=True,
    )
35 changes: 33 additions & 2 deletions examples/gnn_sb3_maze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

import numpy as np
import numpy.typing as npt
from gymnasium.spaces import Box, Discrete, Graph, GraphInstance

from skdecide.builders.domain import Renderable, UnrestrictedActions
Expand All @@ -10,6 +11,7 @@
from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO
from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.utils import rollout

Expand Down Expand Up @@ -157,6 +159,18 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any:
maze_memory = self._graph2mazestate(memory)
self.maze_domain._render_from(memory=maze_memory, **kwargs)

def action_masks(self) -> npt.NDArray[np.bool_]:
    """Return a boolean mask over the action space, True for applicable actions.

    An action is deemed applicable iff taking it from the current memory
    state actually changes the maze state (moving into a wall leaves the
    state unchanged, so such actions are masked out).

    Note: ``npt.NDArray`` must be parametrized with a numpy scalar type;
    ``np.bool_`` is used instead of the builtin ``bool``, which is not a
    valid ``NDArray`` type parameter for static type checkers.
    """
    mazestate_memory = self._graph2mazestate(self._memory)
    return np.array(
        [
            self._graph2mazestate(
                self._get_next_state(action=action, memory=self._memory)
            )
            != mazestate_memory
            for action in self._get_action_space().get_elements()
        ]
    )


MAZE = """
+-+-+-+-+o+-+-+--+-+-+
Expand Down Expand Up @@ -196,9 +210,26 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any:
algo_class=GraphPPO,
baselines_policy="GraphInputPolicy",
learn_config={"total_timesteps": 100},
# batch_size=1,
# normalize_advantage=False
) as solver:

solver.solve()
rollout(domain=domain_factory(), solver=solver, max_steps=max_steps, num_episodes=1)

# Solve with sb3-MaskableGraphPPO: PPO on graph observations with action
# masking, relying on GraphMaze.action_masks() during training and rollout.
domain_factory = lambda: GraphMaze(maze_str=MAZE)
with StableBaseline(
    domain_factory=domain_factory,
    algo_class=MaskableGraphPPO,
    baselines_policy="GraphInputPolicy",
    # Tiny budget: this is a smoke-test example, not a real training run.
    learn_config={"total_timesteps": 100},
    use_action_masking=True,
) as solver:

    solver.solve()
    rollout(
        domain=domain_factory(),
        solver=solver,
        # NOTE(review): hard-coded 100 here, while the GraphPPO rollout above
        # uses the max_steps variable — confirm whether this is intentional.
        max_steps=100,
        num_episodes=1,
        use_action_masking=True,
    )
20 changes: 17 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ numpy = { version = "^1.20.1", optional = true }
matplotlib = { version = ">=3.3.4", optional = true }
joblib = { version = ">=1.0.1", optional = true }
stable-baselines3 = { version = ">=2.0.0", optional = true }
sb3_contrib = { version = ">=2.3", optional = true }
ray = { extras = ["rllib"], version = ">=2.9.0, <2.38", optional = true }
discrete-optimization = { version = ">=0.5.0" }
openap = { version = ">=1.5", python = ">=3.9", optional = true }
Expand Down Expand Up @@ -105,6 +106,7 @@ solvers = [
"joblib",
"ray",
"stable-baselines3",
"sb3_contrib",
"unified-planning",
"up-tamer",
"up-fast-downward",
Expand All @@ -122,6 +124,7 @@ all = [
"joblib",
"ray",
"stable-baselines3",
"sb3_contrib",
"openap",
"pygeodesy",
"unified-planning",
Expand Down
10 changes: 10 additions & 0 deletions skdecide/hub/domain/gym/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,10 @@ def __init__(self, domain: Domain, unwrap_spaces: bool = True) -> None:
domain.get_action_space()
) # assumes all actions are always applicable

@property
def domain(self) -> Domain:
    """The scikit-decide domain wrapped by this legacy Gym environment."""
    return self._domain

def step(self, action):
"""Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for
calling `reset()` to reset this environment's state.
Expand Down Expand Up @@ -1280,6 +1284,8 @@ def unwrapped(self):
class AsGymnasiumEnv(EnvCompatibility):
"""This class wraps a scikit-decide domain as a gymnasium environment."""

env: AsLegacyGymV21Env

def __init__(
self,
domain: Domain,
Expand All @@ -1288,3 +1294,7 @@ def __init__(
) -> None:
legacy_env = AsLegacyGymV21Env(domain=domain, unwrap_spaces=unwrap_spaces)
super().__init__(old_env=legacy_env, render_mode=render_mode)

@property
def domain(self) -> Domain:
    """The scikit-decide domain of the wrapped legacy environment."""
    return self.env.domain
40 changes: 38 additions & 2 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import torch as th
import torch_geometric as thg
from gymnasium import spaces
from sb3_contrib.common.maskable.buffers import (
MaskableRolloutBuffer,
MaskableRolloutBufferSamples,
)
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.type_aliases import (
Expand All @@ -19,6 +23,7 @@

class GraphRolloutBuffer(RolloutBuffer):
observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]
tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

def __init__(
self,
Expand Down Expand Up @@ -96,9 +101,8 @@ def get(
for vec_obs in self.observations:
self.raw_observations.extend(vec_obs)
self.observations = self.raw_observations
_tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

for tensor in _tensor_names:
for tensor in self.tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True

Expand Down Expand Up @@ -130,3 +134,35 @@ def _get_samples(
return RolloutBufferSamples(
selected_observations, *tuple(map(self.to_torch, data))
)


class MaskableGraphRolloutBuffer(GraphRolloutBuffer, MaskableRolloutBuffer):
    """Rollout buffer combining graph observations with action masks.

    Mixes ``GraphRolloutBuffer`` (graph-observation storage) with
    sb3_contrib's ``MaskableRolloutBuffer`` (per-step action masks).

    NOTE(review): ``add()`` and ``_get_samples()`` read ``self.action_masks``,
    ``self.pos``, ``self.n_envs`` and ``self.mask_dims``, which are assumed to
    be set up by ``MaskableRolloutBuffer`` (reset/__init__) via the MRO —
    confirm against the sb3_contrib implementation.
    """

    # Extends GraphRolloutBuffer.tensor_names with "action_masks" so that the
    # masks are swapped/flattened alongside the other tensors in get().
    tensor_names = [
        "actions",
        "values",
        "log_probs",
        "advantages",
        "returns",
        "action_masks",
    ]

    def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
        """Add a transition to the buffer, storing its action mask if given.

        :param action_masks: Masks applied to constrain the choice of possible actions.
        """
        # Store the mask for the current write position before the parent
        # add() advances self.pos.
        if action_masks is not None:
            self.action_masks[self.pos] = action_masks.reshape(
                (self.n_envs, self.mask_dims)
            )

        super().add(*args, **kwargs)

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> MaskableRolloutBufferSamples:
        """Fetch a minibatch of samples, augmented with their action masks.

        Reuses the parent (graph-aware) sampling, then appends the matching
        action masks as the extra field of ``MaskableRolloutBufferSamples``.
        """
        samples_wo_action_masks = super()._get_samples(batch_inds=batch_inds, env=env)
        return MaskableRolloutBufferSamples(
            *samples_wo_action_masks,
            action_masks=self.action_masks[batch_inds].reshape(-1, self.mask_dims),
        )
Loading

0 comments on commit 9550a07

Please sign in to comment.