Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to GraphPPO for multi inputs with Dict spaces (including graphs) #446

Closed
wants to merge 6 commits into from

Conversation

nhuet
Copy link
Contributor

@nhuet nhuet commented Dec 3, 2024

No description provided.

nhuet added 4 commits December 5, 2024 10:18
- we reuse our stable_baselines3 wrapper
- the policy is extracting features from the graph with a GNN
- the GNN is using pytorch-geometric
- We subclass
  - ActorCriticPolicy:
    - feature extractor = gnn
    - custom conversion of observation to torch to convert into
      torch_geometric.data.Data
  - PPO to handle properly
    - observation conversion
    - rollout buffer
- Current limitations:
  - we extract a fixed number of features (independent of edge/node
    numbers) for now as we end with a feature reduction layer connected
    to a classic mlp (not knowning anything about the current graph structure)
- User input: the user can define (and default choices are made else)
  - the gnn (default to a 2 layers GCN), taking as inputs w.r.t torch_geometric conventions:
    - x: nodes features
    - edge_index: edge indices or sparse transposed adjency matrix
    - edge_attr (optional): edges features
    - edge_weight (optional): edge weights (taken from first dimension
      of edge_attr)
  - the feature reduction layer from the gnn output to the fixed number of features
    (default to global_max_pool + linear layer + relu)
- We update the signature of sample_action(), get_next_action_distribution()
  and get_next_action() with **kwargs to allow some solver to use extra
  argument like action masks.
- We add `kwargs_sample_action_fn` arg to derive generic kwargs
  to pass to `sample_action()` during rollout
- We add `use_action_masking` flag as a shortcut to define appropriate
  `kwargs_sample_action_fn` using
  - domain.action_masks() is existing
  - `domain.is_applicable_action()` provided that `domain.get_action_space()` is a `skdecide.core.EnumerableSpace`
@nhuet nhuet force-pushed the multiinput-graph-sb3 branch from 225a677 to 2c3d1ff Compare December 5, 2024 09:19
@nhuet nhuet force-pushed the multiinput-graph-sb3 branch from c46ea2a to ec4b30d Compare December 6, 2024 12:56
@nhuet
Copy link
Contributor Author

nhuet commented Dec 13, 2024

now included in #441.

@nhuet nhuet closed this Dec 13, 2024
@nhuet nhuet deleted the multiinput-graph-sb3 branch January 20, 2025 09:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant