Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a characteristic for solvers using action masks and make use of it in rollout #445

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion skdecide/builders/domain/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import functools
from typing import Optional, Union

from skdecide.core import D, EmptySpace, Space, autocastable
import numpy as np

from skdecide.core import D, EmptySpace, EnumerableSpace, Mask, Space, autocastable

__all__ = ["Events", "Actions", "UnrestrictedActions"]

Expand Down Expand Up @@ -326,6 +328,77 @@ def _is_applicable_action_from(
else: # StrDict
return all(applicable_actions[k].contains(v) for k, v in action.items())

@autocastable
def get_action_mask(
    self, memory: Optional[D.T_memory[D.T_state]] = None
) -> D.T_agent[Mask]:
    """Get action mask for the given memory or internal one if omitted.

    An action mask is another (more specific) format for applicable actions, which is meaningful only if the
    action space can be iterated over in some way. It is represented by a flat array of 0's and 1's ordered as the
    actions when enumerated: 1 for an applicable action, and 0 for a non-applicable action.

    This public method only delegates to `self._get_action_mask()`. The implementation of `_get_action_mask()`
    provided by this class enumerates each agent action space with `get_elements()` (i.e. it assumes each agent
    action space is an `EnumerableSpace`) and tests every enumerated action against the result of
    `self._get_applicable_actions()`.

    The action mask is used for instance by RL solvers to shut down logits associated with non-applicable actions
    in the output of their internal neural network.

    # Parameters
    memory: The memory to consider. If None, works on the internal memory of the domain.

    # Returns
    A numpy array (or dict agent -> numpy array for multi-agent domains) with 0-1 indicating applicability of
    the action (1 meaning applicable and 0 not applicable).
    """
    return self._get_action_mask(memory=memory)

def _get_action_mask(
    self, memory: Optional[D.T_memory[D.T_state]] = None
) -> D.T_agent[Mask]:
    """Get action mask for the given memory or internal one if omitted.

    An action mask is another (more specific) format for applicable actions, which is meaningful only if the
    action space can be iterated over in some way. It is represented by a flat array of 0's and 1's ordered as the
    actions when enumerated: 1 for an applicable action, and 0 for a non-applicable action.

    More precisely, this implementation assumes that each agent action space can be enumerated with
    `get_elements()` (i.e. is an `EnumerableSpace`). It queries `self._get_applicable_actions(memory=memory)` and
    `self._get_action_space()`, then tests every enumerated action with `contains()` on the applicable actions.

    The action mask is used for instance by RL solvers to shut down logits associated with non-applicable actions
    in the output of their internal neural network.

    # Parameters
    memory: The memory to consider. If None, works on the internal memory of the domain.

    # Returns
    A numpy array of dtype int8 (or dict agent -> numpy array for multi-agent domains) with 0-1 indicating
    applicability of the action (1 meaning applicable and 0 not applicable), ordered as `get_elements()`.
    """

    def mask_from(space, applicable) -> np.ndarray:
        # Shared by the single- and multi-agent branches: one int8 flag per
        # enumerated action, in the enumeration order of `space`.
        return np.array(
            [1 if applicable.contains(a) else 0 for a in space.get_elements()],
            dtype=np.int8,
        )

    applicable_actions = self._get_applicable_actions(memory=memory)
    action_space = self._get_action_space()
    if self.T_agent == Union:
        # single agent: spaces are used directly
        return mask_from(action_space, applicable_actions)
    # multi agent: spaces are indexed by agent; the returned dict follows the
    # agents present in the applicable actions
    return {
        agent: mask_from(action_space[agent], agent_applicable_actions)
        for agent, agent_applicable_actions in applicable_actions.items()
    }


class Actions(Events):
"""A domain must inherit this class if it handles only actions (i.e. controllable events)."""
Expand Down
1 change: 1 addition & 0 deletions skdecide/builders/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from skdecide.builders.solver.assessability import *
from skdecide.builders.solver.fromanystatesolvability import *
from skdecide.builders.solver.maskability import *
from skdecide.builders.solver.parallelability import *
from skdecide.builders.solver.policy import *
from skdecide.builders.solver.restorability import *
97 changes: 97 additions & 0 deletions skdecide/builders/solver/maskability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from skdecide import D, autocastable
from skdecide.core import Mask

if TYPE_CHECKING:
# avoid circular import
from skdecide import Domain


__all__ = ["ApplicableActions", "Maskable"]


class ApplicableActions:
    """A solver must inherit this class if it can use information about applicable actions.

    This characteristic will be checked during rollout so that `retrieve_applicable_actions()` will be called before
    each call to `step()`. For instance, this is the case for solvers using action masks (see `Maskable`).

    """

    def using_applicable_actions(self) -> bool:
        """Tell if the solver is able to use applicable actions information.

        For instance, action masking could be possible only if
        the considered domain action space is enumerable for each agent.

        The default implementation always returns True.

        # Returns
        True if the solver can make use of applicable actions information.

        """
        return True

    def retrieve_applicable_actions(self, domain: Domain) -> None:
        """Retrieve applicable actions and use them for future calls to `self.step()`.

        To be called during rollout to get the actual applicable actions from the actual domain used in rollout.

        This base implementation raises `NotImplementedError`: subclasses must override it.

        # Parameters
        domain: The domain from which the applicable actions are retrieved.

        """
        raise NotImplementedError


class Maskable(ApplicableActions):
    """A solver must inherit this class if it can use action masks to sample actions.

    For instance, it can be the case for wrappers around RL solvers like `sb3_contrib.MaskablePPO` or `ray.rllib` with
    a custom model making use of action masking.

    An action mask is a format for specifying applicable actions when the action space is enumerable and finite. It is
    an array with 0's (for non-applicable actions) and 1's (for applicable actions). See `Events.get_action_mask()` for
    more information.

    """

    # Last action mask stored via `_set_action_mask()`; None until one is set.
    _action_mask: Optional[D.T_agent[Mask]] = None

    def retrieve_applicable_actions(self, domain: Domain) -> None:
        """Retrieve applicable actions and use them for future calls to `self.step()`.

        To be called during rollout to get the actual applicable actions from the actual domain used in rollout.
        The applicable actions are obtained as an action mask via `domain.get_action_mask()` and stored with
        `self.set_action_mask()`, to be used when sampling actions.

        # Parameters
        domain: The domain whose current action mask is retrieved.

        """
        self.set_action_mask(domain.get_action_mask())

    @autocastable
    def set_action_mask(self, action_mask: Optional[D.T_agent[Mask]]) -> None:
        """Set the action mask.

        To be called during rollout before `self.sample_action()`, assuming that
        `self.sample_action()` knows what to do with it.

        Autocastable so that it can use action_mask from original domain during rollout.
        Delegates the actual storage to `self._set_action_mask()`.

        # Parameters
        action_mask: The action mask to store (or None).

        """
        self._set_action_mask(action_mask=action_mask)

    def _set_action_mask(self, action_mask: Optional[D.T_agent[Mask]]) -> None:
        """Set the action mask.

        To be called during rollout before `self.sample_action()`, assuming that
        `self.sample_action()` knows what to do with it.

        Stores the given mask in `self._action_mask`, from where `self.get_action_mask()` returns it.

        # Parameters
        action_mask: The action mask to store (or None).

        """

        self._action_mask = action_mask

    def get_action_mask(self) -> Optional[D.T_agent[Mask]]:
        """Retrieve the stored action mask.

        To be used by `self.sample_action()`.
        Returns None if `self.set_action_mask()` was not called.

        """
        return self._action_mask
10 changes: 10 additions & 0 deletions skdecide/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from dataclasses import asdict, astuple, dataclass, replace
from typing import Generic, Optional, TypeVar, Union

import numpy as np
import numpy.typing as npt

__all__ = [
"T",
"D",
Expand Down Expand Up @@ -666,6 +669,13 @@ def cast_evaluate_function(memory, action, next_state):
)


# Type alias for an action mask: a numpy array of dtype int8.
# It has to be defined in this (core) module so that autocast works:
# - `autocast` does not like "." after strings other than "D",
# - `autocast` needs types in annotations to be evaluable in `skdecide.core` namespace.
Mask = npt.NDArray[np.int8]
"""Alias for single agent action mask."""


SINGLE_AGENT_ID = "agent"

# (auto)cast-related objects/functions
Expand Down
Loading
Loading