
Add maskable GraphPPO based on sb3_contrib.MaskablePPO + GNN for domains with graph observations #444

Draft: nhuet wants to merge 5 commits into master from gnn-sb3-maskable
Conversation

@nhuet (Contributor) commented Nov 29, 2024

  • Derive MaskableGraphPPO from MaskablePPO (and also from GraphOnPolicyAlgorithm, to reuse the work done for GraphPPO).
  • Derive MaskableGNNActorCriticPolicy from MaskableActorCriticPolicy (and similarly from _BaseGNNActorCriticPolicy, to share code with GNNActorCriticPolicy); likewise for MaskableMultiInputGNNActorCriticPolicy.
  • Update the stable_baselines3 scikit-decide wrapper with a new argument use_action_masking that wraps the domain in an environment exposing an action_masks() method (redirecting to domain.get_action_mask()).
  • Add examples for GraphMaze and GraphJspDomain so that the solver proposes only applicable actions (see the usage sketch below).
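A minimal usage sketch, assuming the solver is driven through scikit-decide's StableBaseline wrapper as in the existing GraphPPO examples; the import paths for MaskableGraphPPO and GraphJspDomain, the policy alias, and the problem instance are assumptions for illustration, not the PR's confirmed API:

```python
from skdecide.hub.solver.stable_baselines import StableBaseline

# MaskableGraphPPO and GraphJspDomain import paths are PR-specific and
# omitted here; jsp_instance is a hypothetical problem instance.
solver = StableBaseline(
    domain_factory=lambda: GraphJspDomain(jsp_instance),
    algo_class=MaskableGraphPPO,          # maskable PPO + GNN from this PR
    baselines_policy="GraphInputPolicy",  # assumed policy alias
    use_action_masking=True,              # expose action_masks() to the algorithm
    learn_config={"total_timesteps": 10_000},
)
solver.solve()
```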

@nhuet force-pushed the gnn-sb3-maskable branch 2 times, most recently from f52083e to 9550a07 on December 5, 2024 09:20
- Use it in rollouts so that they are aware of the current action mask.
- Add a `get_action_mask()` method to domains, by default converting the
  applicable-actions space into a 0-1 numpy array, provided that the
  action space of each agent is an EnumerableSpace (see the sketch below).
- Inherit from Maskable.
- No longer require the domain to be FullyObservable to use action
  masking, as get_action_mask() can be called without the solver knowing
  the current state (and since, in rollout, the actual domain is now
  used).
- Decide whether to use action masking directly in __init__() so that
  using_applicable_actions() can be overridden properly.
- Use common functions for unwrap_obs and wrap_action in the solver and
  the wrapper environment to avoid code duplication.
- Use domain.get_action_mask() to convert applicable actions into a mask
  (the method is more efficient, as it does not call get_applicable_actions()
  for each action).
This is more memory-efficient, since the mask stores only 0s and 1s.
It also seems to be the standard representation for action masks, at least
for ray.rllib, as shown in the `action_mask_key` documentation at
https://docs.ray.io/en/latest/rllib/rllib-training.html
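A minimal sketch of the default `get_action_mask()` behavior described above, assuming the agent's action space is an EnumerableSpace; the method and helper names follow scikit-decide conventions, but the PR's actual implementation may differ:

```python
import numpy as np

def get_action_mask(self) -> np.ndarray:
    # Enumerate the full action space (assumed to be an EnumerableSpace)
    # and mark applicable actions with 1, all others with 0.
    actions = self.get_action_space().get_elements()
    applicable = self.get_applicable_actions()
    # int8 keeps the 0-1 mask memory-friendly, matching the rllib convention.
    return np.array(
        [1 if applicable.contains(a) else 0 for a in actions],
        dtype=np.int8,
    )
```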
- We reuse our stable_baselines3 wrapper.
- The policy extracts features from the graph with a GNN.
- The GNN uses pytorch-geometric.
- We subclass:
  - ActorCriticPolicy:
    - feature extractor = GNN
    - custom conversion of observations to torch, turning them into
      torch_geometric.data.Data (see the sketch below)
  - PPO, to properly handle:
    - observation conversion
    - rollout buffer
- Current limitations:
  - we extract a fixed number of features (independent of node/edge
    counts) for now, as we end with a feature-reduction layer connected
    to a classic MLP (which knows nothing about the current graph structure)
- User input: the user can define the following (default choices are made otherwise):
  - the GNN (defaults to a 2-layer GCN), taking as inputs, following torch_geometric conventions:
    - x: node features
    - edge_index: edge indices or a sparse transposed adjacency matrix
    - edge_attr (optional): edge features
    - edge_weight (optional): edge weights (taken from the first dimension
      of edge_attr)
  - the feature-reduction layer from the GNN output to the fixed number of features
    (defaults to global_max_pool + linear layer + relu); see the sketch below
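A minimal sketch of these pieces, assuming torch and torch_geometric are installed; the class and function names here are illustrative, not the PR's actual identifiers:

```python
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_max_pool


def graph_obs_to_data(obs) -> Data:
    # Convert a gymnasium GraphInstance-like observation (nodes, edges,
    # edge_links) into a torch_geometric.data.Data object.
    return Data(
        x=torch.as_tensor(obs.nodes, dtype=torch.float32),
        edge_index=torch.as_tensor(obs.edge_links, dtype=torch.long).t().contiguous(),
        edge_attr=torch.as_tensor(obs.edges, dtype=torch.float32),
    )


class TwoLayerGCN(nn.Module):
    # Default GNN: a 2-layer GCN following torch_geometric conventions
    # (x, edge_index, optional edge_weight).
    def __init__(self, in_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_weight=None):
        h = self.conv1(x, edge_index, edge_weight).relu()
        return self.conv2(h, edge_index, edge_weight)


class FeatureReduction(nn.Module):
    # Default reduction: global_max_pool over nodes, then a linear layer
    # and a relu, yielding a fixed number of features for the MLP head.
    def __init__(self, gnn_out_dim: int, features_dim: int):
        super().__init__()
        self.linear = nn.Linear(gnn_out_dim, features_dim)

    def forward(self, h: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        return self.linear(global_max_pool(h, batch)).relu()
```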

We also introduce a multi-input policy to take static graph features into
account. The observation space is a Dict space whose subspaces can include
Graph spaces (see the example below).
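An illustrative observation space for such a multi-input policy, mixing a Graph subspace with flat static features; the keys and shapes are arbitrary examples:

```python
import numpy as np
from gymnasium import spaces

observation_space = spaces.Dict(
    {
        # graph observation: 4 features per node, 1 feature per edge
        "graph": spaces.Graph(
            node_space=spaces.Box(low=-np.inf, high=np.inf, shape=(4,)),
            edge_space=spaces.Box(low=0.0, high=1.0, shape=(1,)),
        ),
        # static (non-graph) features
        "static": spaces.Box(low=0.0, high=1.0, shape=(3,)),
    }
)
```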