diff --git a/examples/gnn_sb3_jsp.py b/examples/gnn_sb3_jsp.py index 268a13a812..f9933a1f73 100644 --- a/examples/gnn_sb3_jsp.py +++ b/examples/gnn_sb3_jsp.py @@ -9,6 +9,7 @@ from skdecide.hub.domain.gym import GymDomain from skdecide.hub.solver.stable_baselines import StableBaseline from skdecide.hub.solver.stable_baselines.gnn import GraphPPO +from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO from skdecide.hub.space.gym import GymSpace, ListSpace from skdecide.utils import rollout @@ -40,6 +41,9 @@ def _state_step( outcome.state = self._np_state2graph_state(outcome.state) return outcome + def action_masks(self): + return self._gym_env.valid_action_mask() + def _get_applicable_actions_from( self, memory: D.T_memory[D.T_state] ) -> D.T_agent[Space[D.T_event]]: @@ -120,21 +124,37 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: action_mode="task", ) - # random rollout domain = GraphJspDomain(gym_env=jsp_env) rollout(domain=domain, max_steps=jsp_env.total_tasks_without_dummies, num_episodes=1) -# solve with sb3-PPO-GNN +# solve with sb3-GraphPPO domain_factory = lambda: GraphJspDomain(gym_env=jsp_env) with StableBaseline( domain_factory=domain_factory, algo_class=GraphPPO, baselines_policy="GraphInputPolicy", learn_config={"total_timesteps": 100}, - # batch_size=1, - # normalize_advantage=False ) as solver: solver.solve() rollout(domain=domain_factory(), solver=solver, max_steps=100, num_episodes=1) + +# solver with sb3-MaskableGraphPPO +domain_factory = lambda: GraphJspDomain(gym_env=jsp_env) +with StableBaseline( + domain_factory=domain_factory, + algo_class=MaskableGraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + use_action_masking=True, +) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + use_action_masking=True, + ) diff --git a/examples/gnn_sb3_maze.py b/examples/gnn_sb3_maze.py index 923b3b8a5d..6d5f9426ed 100644 --- a/examples/gnn_sb3_maze.py +++ b/examples/gnn_sb3_maze.py @@ -1,6 +1,7 @@ from typing import Any, Optional import numpy as np +import numpy.typing as npt from gymnasium.spaces import Box, Discrete, Graph, GraphInstance from skdecide.builders.domain import Renderable, UnrestrictedActions @@ -10,6 +11,7 @@ from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State from skdecide.hub.solver.stable_baselines import StableBaseline from skdecide.hub.solver.stable_baselines.gnn import GraphPPO +from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO from skdecide.hub.space.gym import GymSpace, ListSpace from skdecide.utils import rollout @@ -157,6 +159,18 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: maze_memory = self._graph2mazestate(memory) self.maze_domain._render_from(memory=maze_memory, **kwargs) + def action_masks(self) -> npt.NDArray[bool]: + mazestate_memory = self._graph2mazestate(self._memory) + return np.array( + [ + self._graph2mazestate( + self._get_next_state(action=action, memory=self._memory) + ) + != mazestate_memory + for action in self._get_action_space().get_elements() + ] + ) + MAZE = """ +-+-+-+-+o+-+-+--+-+-+ @@ -196,9 +210,26 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: algo_class=GraphPPO, baselines_policy="GraphInputPolicy", learn_config={"total_timesteps": 100}, - # batch_size=1, - # normalize_advantage=False ) as solver: solver.solve() rollout(domain=domain_factory(), solver=solver, 
max_steps=max_steps, num_episodes=1) + +# solver with sb3-MaskableGraphPPO +domain_factory = lambda: GraphMaze(maze_str=MAZE) +with StableBaseline( + domain_factory=domain_factory, + algo_class=MaskableGraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + use_action_masking=True, +) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + use_action_masking=True, + ) diff --git a/poetry.lock b/poetry.lock index 1cd3b25a20..fb48e7b8da 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5023,6 +5023,20 @@ files = [ {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, ] +[[package]] +name = "sb3-contrib" +version = "2.3.0" +description = "Contrib package of Stable Baselines3, experimental code." +optional = true +python-versions = ">=3.8" +files = [ + {file = "sb3_contrib-2.3.0-py3-none-any.whl", hash = "sha256:79dadda884f2242bbab11fdb06ccf00eeead2937f7ed09ba577f803f2e52c8e1"}, + {file = "sb3_contrib-2.3.0.tar.gz", hash = "sha256:9755ef9efa08e43afb9c4e79564ed594eb8bc497256920403a7e269a72216221"}, +] + +[package.dependencies] +stable-baselines3 = ">=2.3.0,<3.0" + [[package]] name = "scikit-image" version = "0.24.0" @@ -6159,11 +6173,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "pyRDDLGym-rl", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "ray", "rddlrepository", "scipy", "stable-baselines3", "torch-geometric", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] +all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "pyRDDLGym-rl", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "ray", "rddlrepository", "sb3_contrib", "scipy", "stable-baselines3", "torch-geometric", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] domains = ["cartopy", "gymnasium", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym", "pyRDDLGym-rl", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "rddlrepository", "scipy", "unified-planning"] -solvers = ["gymnasium", "joblib", "numpy", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "ray", "scipy", "stable-baselines3", "torch-geometric", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] +solvers = ["gymnasium", "joblib", "numpy", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "ray", "sb3_contrib", "scipy", "stable-baselines3", "torch-geometric", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "2b17ad02ae15987e4983858da614b1d71e27e1ef9ff9ed1e4b95d32e1ed02813" +content-hash = "c04dc99a6ada248e1fb56a0285536bce688415553f183d03237ad440d3559dcf" diff --git a/pyproject.toml b/pyproject.toml index 44b62ba7b6..1d60c5666d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ numpy = { version = "^1.20.1", optional = true } matplotlib = { version = ">=3.3.4", optional = true } joblib = { version = ">=1.0.1", optional = true } stable-baselines3 = { version = ">=2.0.0", optional = true } +sb3_contrib = { version = ">=2.3", optional = true } ray = { extras = ["rllib"], version = ">=2.9.0, <2.38", optional = true } discrete-optimization = { version = 
">=0.5.0" } openap = { version = ">=1.5", python = ">=3.9", optional = true } @@ -105,6 +106,7 @@ solvers = [ "joblib", "ray", "stable-baselines3", + "sb3_contrib", "unified-planning", "up-tamer", "up-fast-downward", @@ -122,6 +124,7 @@ all = [ "joblib", "ray", "stable-baselines3", + "sb3_contrib", "openap", "pygeodesy", "unified-planning", diff --git a/skdecide/hub/domain/gym/gym.py b/skdecide/hub/domain/gym/gym.py index 8cae2bea2a..9285bbbcbe 100644 --- a/skdecide/hub/domain/gym/gym.py +++ b/skdecide/hub/domain/gym/gym.py @@ -1175,6 +1175,10 @@ def __init__(self, domain: Domain, unwrap_spaces: bool = True) -> None: domain.get_action_space() ) # assumes all actions are always applicable + @property + def domain(self) -> Domain: + return self._domain + def step(self, action): """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` to reset this environment's state. @@ -1280,6 +1284,8 @@ def unwrapped(self): class AsGymnasiumEnv(EnvCompatibility): """This class wraps a scikit-decide domain as a gymnasium environment.""" + env: AsLegacyGymV21Env + def __init__( self, domain: Domain, @@ -1288,3 +1294,7 @@ def __init__( ) -> None: legacy_env = AsLegacyGymV21Env(domain=domain, unwrap_spaces=unwrap_spaces) super().__init__(old_env=legacy_env, render_mode=render_mode) + + @property + def domain(self) -> Domain: + return self.env.domain diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py b/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py index 602e784ba8..864d825081 100644 --- a/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py +++ b/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py @@ -5,6 +5,10 @@ import torch as th import torch_geometric as thg from gymnasium import spaces +from sb3_contrib.common.maskable.buffers import ( + MaskableRolloutBuffer, + MaskableRolloutBufferSamples, +) from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.type_aliases import ( @@ -19,6 +23,7 @@ class GraphRolloutBuffer(RolloutBuffer): observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]] + tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] def __init__( self, @@ -96,9 +101,8 @@ def get( for vec_obs in self.observations: self.raw_observations.extend(vec_obs) self.observations = self.raw_observations - _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] - for tensor in _tensor_names: + for tensor in self.tensor_names: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True @@ -130,3 +134,35 @@ def _get_samples( return RolloutBufferSamples( selected_observations, *tuple(map(self.to_torch, data)) ) + + +class MaskableGraphRolloutBuffer(GraphRolloutBuffer, MaskableRolloutBuffer): + + tensor_names = [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "action_masks", + ] + + def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None: + """ + :param action_masks: Masks applied to constrain the choice of possible actions. 
+ """ + if action_masks is not None: + self.action_masks[self.pos] = action_masks.reshape( + (self.n_envs, self.mask_dims) + ) + + super().add(*args, **kwargs) + + def _get_samples( + self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None + ) -> MaskableRolloutBufferSamples: + samples_wo_action_masks = super()._get_samples(batch_inds=batch_inds, env=env) + return MaskableRolloutBufferSamples( + *samples_wo_action_masks, + action_masks=self.action_masks[batch_inds].reshape(-1, self.mask_dims), + ) diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py b/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py index 6824c75e21..78942155d4 100644 --- a/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py +++ b/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py @@ -3,6 +3,11 @@ import numpy as np import torch as th from gymnasium import spaces +from sb3_contrib.common.maskable.buffers import ( + MaskableDictRolloutBuffer, + MaskableRolloutBuffer, +) +from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm @@ -19,6 +24,13 @@ class GraphOnPolicyAlgorithm(OnPolicyAlgorithm): """Base class for On-Policy algorithms (ex: A2C/PPO) with graph observations.""" + support_action_masking = False + """Whether this algorithm supports action masking. + + Useful to share the code between algorithms. + + """ + def __init__( self, policy: Union[str, type[ActorCriticPolicy]], @@ -51,17 +63,21 @@ def collect_rollouts( callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int, + use_masking: bool = False, ) -> bool: """ Collect experiences using the current policy and fill a ``RolloutBuffer``. The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. + This method is largely identical to the implementation found in the parent class and MaskablePPO. + :param env: The training environment :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) :param rollout_buffer: Buffer to fill with rollouts :param n_rollout_steps: Number of experiences to collect per environment + :param use_masking: Whether to use invalid action masks during training :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. """ @@ -70,7 +86,23 @@ def collect_rollouts( self.policy.set_training_mode(False) n_steps = 0 + action_masks = None rollout_buffer.reset() + + if ( + use_masking + and self.support_action_masking + and not is_masking_supported(env) + ): + raise ValueError( + "Environment does not support action masking. Consider using ActionMasker wrapper" + ) + + if use_masking and not self.support_action_masking: + raise ValueError( + f"The algorithm {self.__class__.__name__} does not support action masking." 
+ ) + # Sample new weights for the state dependent exploration if self.use_sde: self.policy.reset_noise(env.num_envs) @@ -94,7 +126,16 @@ def collect_rollouts( ) else: obs_tensor = obs_as_tensor(self._last_obs, self.device) - actions, values, log_probs = self.policy(obs_tensor) + + if use_masking and self.support_action_masking: + action_masks = get_action_masks(env) + + if self.support_action_masking: + actions, values, log_probs = self.policy( + obs_tensor, action_masks=action_masks + ) + else: + actions, values, log_probs = self.policy(obs_tensor) actions = actions.cpu().numpy() # Rescale and perform action @@ -143,14 +184,27 @@ def collect_rollouts( terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] rewards[idx] += self.gamma * terminal_value - rollout_buffer.add( - self._last_obs, # type: ignore[arg-type] - actions, - rewards, - self._last_episode_starts, # type: ignore[arg-type] - values, - log_probs, - ) + if isinstance( + rollout_buffer, (MaskableRolloutBuffer, MaskableDictRolloutBuffer) + ): + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs, + action_masks=action_masks, + ) + else: + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs, + ) self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/policies.py b/skdecide/hub/solver/stable_baselines/gnn/common/policies.py index fcb6279fd3..8b710ed4e6 100644 --- a/skdecide/hub/solver/stable_baselines/gnn/common/policies.py +++ b/skdecide/hub/solver/stable_baselines/gnn/common/policies.py @@ -2,10 +2,13 @@ from typing import Any, Optional, Tuple, Union import gymnasium as gym +import numpy as np import torch as th import torch_geometric as thg +from sb3_contrib.common.maskable.distributions import MaskableDistribution +from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from stable_baselines3.common.distributions import Distribution -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.torch_layers import BaseFeaturesExtractor from stable_baselines3.common.type_aliases import Schedule @@ -15,47 +18,7 @@ PyTorchGraphObs = Union[thg.data.Data, list[thg.data.Data]] -class GNNActorCriticPolicy(ActorCriticPolicy): - def __init__( - self, - observation_space: gym.spaces.Graph, - action_space: gym.spaces.Space, - lr_schedule: Schedule, - net_arch: Optional[list[Union[int, dict[str, list[int]]]]] = None, - activation_fn: type[th.nn.Module] = th.nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: type[BaseFeaturesExtractor] = GraphFeaturesExtractor, - features_extractor_kwargs: Optional[dict[str, Any]] = None, - share_features_extractor: bool = True, - normalize_images: bool = True, - optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, - ): - super().__init__( - observation_space=observation_space, - action_space=action_space, - lr_schedule=lr_schedule, - net_arch=net_arch, - activation_fn=activation_fn, - ortho_init=ortho_init, - use_sde=use_sde, - 
log_std_init=log_std_init, - full_std=full_std, - use_expln=use_expln, - squash_output=squash_output, - features_extractor_class=features_extractor_class, - features_extractor_kwargs=features_extractor_kwargs, - share_features_extractor=share_features_extractor, - normalize_images=normalize_images, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - ) - +class _BaseGNNActorCriticPolicy(BasePolicy): def extract_features( self, obs: thg.data.Data, @@ -110,3 +73,89 @@ def predict_values(self, obs: thg.data.Data) -> th.Tensor: features = self.vf_features_extractor(obs) latent_vf = self.mlp_extractor.forward_critic(features) return self.value_net(latent_vf) + + +class GNNActorCriticPolicy(_BaseGNNActorCriticPolicy, ActorCriticPolicy): + def __init__( + self, + observation_space: gym.spaces.Graph, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[list[Union[int, dict[str, list[int]]]]] = None, + activation_fn: type[th.nn.Module] = th.nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: type[BaseFeaturesExtractor] = GraphFeaturesExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__( + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) + + +class MaskableGNNActorCriticPolicy( + _BaseGNNActorCriticPolicy, MaskableActorCriticPolicy +): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[th.nn.Module] = th.nn.Tanh, + ortho_init: bool = True, + features_extractor_class: type[BaseFeaturesExtractor] = GraphFeaturesExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__( + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) + + def get_distribution( + self, obs: thg.data.Data, action_masks: Optional[np.ndarray] = None + ) -> MaskableDistribution: + features = self.pi_features_extractor(obs) + latent_pi = self.mlp_extractor.forward_actor(features) + distribution = self._get_action_dist_from_latent(latent_pi) + 
if action_masks is not None: + distribution.apply_masking(action_masks) + return distribution diff --git a/skdecide/hub/solver/stable_baselines/gnn/ppo_mask/__init__.py b/skdecide/hub/solver/stable_baselines/gnn/ppo_mask/__init__.py new file mode 100644 index 0000000000..3838cc0bce --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/ppo_mask/__init__.py @@ -0,0 +1,4 @@ +from ..common.policies import MaskableGNNActorCriticPolicy +from .ppo_mask import MaskableGraphPPO + +GraphInputPolicy = MaskableGNNActorCriticPolicy diff --git a/skdecide/hub/solver/stable_baselines/gnn/ppo_mask/ppo_mask.py b/skdecide/hub/solver/stable_baselines/gnn/ppo_mask/ppo_mask.py new file mode 100644 index 0000000000..2d1ca3e7c3 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/ppo_mask/ppo_mask.py @@ -0,0 +1,37 @@ +from typing import ClassVar, Optional, Union + +from sb3_contrib import MaskablePPO +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy +from stable_baselines3.common.type_aliases import GymEnv + +from ..common.buffers import MaskableGraphRolloutBuffer +from ..common.on_policy_algorithm import GraphOnPolicyAlgorithm +from ..common.policies import MaskableGNNActorCriticPolicy + + +class MaskableGraphPPO(GraphOnPolicyAlgorithm, MaskablePPO): + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { + "GraphInputPolicy": MaskableGNNActorCriticPolicy, + } + + support_action_masking = True + + def __init__( + self, + policy: Union[str, type[ActorCriticPolicy]], + env: GymEnv, + rollout_buffer_class: Optional[type[RolloutBuffer]] = None, + **kwargs, + ): + + # Use proper default rollout buffer class + if rollout_buffer_class is None: + rollout_buffer_class = MaskableGraphRolloutBuffer + + super().__init__( + policy=policy, + env=env, + rollout_buffer_class=rollout_buffer_class, + **kwargs, + ) diff --git a/skdecide/hub/solver/stable_baselines/stable_baselines.py b/skdecide/hub/solver/stable_baselines/stable_baselines.py index 784eb04321..5eabfbb771 100644 --- a/skdecide/hub/solver/stable_baselines/stable_baselines.py +++ b/skdecide/hub/solver/stable_baselines/stable_baselines.py @@ -7,11 +7,14 @@ from collections.abc import Callable from typing import Any, Optional, Union +import gymnasium as gym +import numpy as np from discrete_optimization.generic_tools.hyperparameters.hyperparameter import ( CategoricalHyperparameter, FloatHyperparameter, IntegerHyperparameter, ) +from sb3_contrib.common.wrappers import ActionMasker from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback @@ -89,6 +92,7 @@ def __init__( baselines_policy: Union[str, type[BasePolicy]], learn_config: Optional[dict[str, Any]] = None, callback: Callable[[StableBaseline], bool] = lambda solver: False, + use_action_masking: bool = False, **kwargs: Any, ) -> None: """Initialize StableBaselines. @@ -100,6 +104,7 @@ def __init__( baselines_policy: The class of Baselines policy network (stable_baselines3.common.policies or str) to use. learn_config: the kwargs passed to sb3 algo's `learn()` method callback: function called at each solver iteration. If returning true, the solve process stops. + use_action_masking: if True, the domain will be wrapped in a gymnasium environment exposing `action_masks()`. kwargs: keyword arguments passed to the algo_class constructor. 
""" @@ -109,6 +114,7 @@ def __init__( self._learn_config = learn_config if learn_config is not None else {} self._algo_kwargs = kwargs self.callback = callback + self.use_action_masking = use_action_masking # Handle kwargs (potentially generated by optuna) if "total_timesteps" in kwargs: @@ -131,6 +137,12 @@ def __init__( ent_coef_log = kwargs.pop("ent_coef_log") kwargs["ent_coef"] = 10**ent_coef_log + def _as_gymnasium_env(self, domain: Domain) -> gym.Env: + if self.use_action_masking: + return as_masked_gymnasium_env(domain) + else: + return as_gymnasium_env(domain) + @classmethod def _check_domain_additional(cls, domain: Domain) -> bool: return isinstance(domain.get_action_space(), GymSpace) and isinstance( @@ -145,7 +157,7 @@ def _solve(self) -> None: self, "_algo" ): # reuse algo if possible (enables further learning) domain = self._domain_factory() - env = AsGymnasiumEnv(domain) # we let the algo wrap it in a vectorized env + env = self._as_gymnasium_env(domain) self._algo = self._algo_class( self._baselines_policy, env, **self._algo_kwargs ) @@ -168,7 +180,7 @@ def _solve(self) -> None: def _sample_action( self, observation: D.T_agent[D.T_observation], **kwargs: Any ) -> D.T_agent[D.T_concurrency[D.T_event]]: - action, _ = self._algo.predict(self._unwrap_obs(observation)) + action, _ = self._algo.predict(self._unwrap_obs(observation), **kwargs) return self._wrap_action(action) def _is_policy_defined_for(self, observation: D.T_agent[D.T_observation]) -> bool: @@ -179,7 +191,7 @@ def _save(self, path: str) -> None: def _load(self, path: str): domain = self._domain_factory() - self._algo = self._algo_class.load(path, env=AsGymnasiumEnv(domain)) + self._algo = self._algo_class.load(path, env=self._as_gymnasium_env(domain)) self._init_algo(domain) def _init_algo(self, domain: D): @@ -205,3 +217,40 @@ def __init__( def _on_step(self) -> bool: return not self.callback(self.solver) + + +def as_gymnasium_env(domain: Domain) -> gym.Env: + """Wraps the domain into a gymnasium env. + + To be fed to sb3 algorithms. + + """ + return AsGymnasiumEnv(domain=domain) + + +def as_masked_gymnasium_env(domain: Domain) -> gym.Env: + """Wraps the domain into an action-masked gymnasium env. + + This means that it exposes a method `self.action_masks()` as expected by algorithms like + `sb3_contrib.MaskablePPO`. + + Uses `domain.action_masks()` when existing, else tries to derive one by using `domain.is_applicable_action()` + provided that `domain.get_action_space()` is a `skdecide.core.EnumerableSpace`. + + For computational efficiency, it is generally better to have properly implemented `domain.action_masks()`. 
+ + """ + env = AsGymnasiumEnv(domain=domain) + if hasattr(domain, "action_masks"): + action_masks_fn = lambda env: env.domain.action_masks() + else: + action_masks_fn = lambda env: np.array( + [ + env.domain.is_applicable_action(action) + for action in domain.get_action_space().get_elements() + ] + ) + return ActionMasker( + env=env, + action_mask_fn=action_masks_fn, + ) diff --git a/tests/solvers/python/test_gnn_sb3.py b/tests/solvers/python/test_gnn_sb3.py index 2e4715c0b8..fc3f329a0a 100644 --- a/tests/solvers/python/test_gnn_sb3.py +++ b/tests/solvers/python/test_gnn_sb3.py @@ -2,6 +2,7 @@ from typing import Any, Optional import numpy as np +import numpy.typing as npt import pytest import torch as th import torch_geometric as thg @@ -20,6 +21,7 @@ from skdecide.hub.solver.stable_baselines.gnn.common.torch_layers import ( GraphFeaturesExtractor, ) +from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO from skdecide.hub.space.gym import GymSpace, ListSpace from skdecide.utils import rollout @@ -62,6 +64,9 @@ def _state_step( outcome.state = self._np_state2graph_state(outcome.state) return outcome + def action_masks(self) -> npt.NDArray[bool]: + return np.array(self._gym_env.valid_action_mask()) + def _get_applicable_actions_from( self, memory: D.T_memory[D.T_state] ) -> D.T_agent[Space[D.T_event]]: @@ -279,6 +284,18 @@ def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: maze_memory = self._graph2mazestate(memory) self.maze_domain._render_from(memory=maze_memory, **kwargs) + def action_masks(self) -> npt.NDArray[bool]: + mazestate_memory = self._graph2mazestate(self._memory) + return np.array( + [ + self._graph2mazestate( + self._get_next_state(action=action, memory=self._memory) + ) + != mazestate_memory + for action in self._get_action_space().get_elements() + ] + ) + discrete_features = param_fixture("discrete_features", [False, True]) @@ -415,3 +432,23 @@ def test_ppo_user_reduction_layer(domain_factory): num_episodes=1, render=False, ) + + +def test_maskable_ppo(domain_factory): + with StableBaseline( + domain_factory=domain_factory, + algo_class=MaskableGraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + use_action_masking=True, + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + use_action_masking=True, + )
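
To make the new pieces easier to picture end to end, here is a short usage sketch. It is illustrative only and not part of the patch: it assumes the `GraphMaze` domain and `MAZE` string defined in `examples/gnn_sb3_maze.py` above, and otherwise relies on objects introduced by this diff (`as_masked_gymnasium_env`, `MaskableGraphPPO`) together with `sb3_contrib`'s `ActionMasker`.

import numpy as np
from sb3_contrib.common.wrappers import ActionMasker

from skdecide.hub.solver.stable_baselines.gnn.ppo_mask import MaskableGraphPPO
from skdecide.hub.solver.stable_baselines.stable_baselines import (
    as_masked_gymnasium_env,
)

# GraphMaze and MAZE are assumed to be defined as in examples/gnn_sb3_maze.py.
domain = GraphMaze(maze_str=MAZE)

# GraphMaze implements action_masks(), so the wrapper forwards to it through the
# new AsGymnasiumEnv.domain property; for domains without action_masks(), the mask
# is instead derived from is_applicable_action() over get_action_space().get_elements().
env = as_masked_gymnasium_env(domain)
assert isinstance(env, ActionMasker)

env.reset()
mask = env.action_masks()  # one boolean per maze Action (up/down/left/right)
assert mask.dtype == np.bool_

# MaskableGraphPPO is used like any sb3_contrib algorithm; during rollout collection
# the masks are fetched from the env, stored in the MaskableGraphRolloutBuffer and
# applied to the policy distribution.
algo = MaskableGraphPPO("GraphInputPolicy", env)
algo.learn(total_timesteps=100)

# At prediction time, masks are passed explicitly, mirroring the kwargs now
# forwarded by StableBaseline._sample_action() to the underlying predict().
obs, _ = env.reset()
action, _ = algo.predict(obs, action_masks=env.action_masks(), deterministic=True)

This is essentially what `StableBaseline._solve()` does internally when `use_action_masking=True`, and the final `predict()` call is the hook that the updated `_sample_action()` exposes for `rollout(..., use_action_masking=True)`. Note also that `sb3_contrib` is now pulled in by the `solvers` and `all` extras.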