Skip to content

Commit

Permalink
Add a sb3 algo + policy for domains with graph observations
Browse files Browse the repository at this point in the history
- we reuse our stable_baselines3 wrapper
- the policy is extracting features from the graph with a GNN
- the GNN is using pytorch-geometric
- We subclass
  - ActorCriticPolicy:
    - feature extractor = gnn
    - custom conversion of observation to torch to convert into
      torch_geometric.data.Data
  - PPO to handle properly
    - observation conversion
    - rollout buffer
- Current limitations:
  - we extract a fixed number of features (independent of edge/node
    numbers) for now as we end with a feature reduction layer connected
    to a classic mlp (not knowning anything about the current graph structure)
- User input: the user can define (and default choices are made else)
  - the gnn (default to a 2 layers GCN), taking as inputs w.r.t torch_geometric conventions:
    - x: nodes features
    - edge_index: edge indices or sparse transposed adjency matrix
    - edge_attr (optional): edges features
    - edge_weight (optional): edge weights (taken from first dimension
      of edge_attr)
  - the feature reduction layer from the gnn output to the fixed number of features
    (default to global_max_pool + linear layer + relu)
  • Loading branch information
nhuet committed Nov 21, 2024
1 parent e5e4a19 commit 772a739
Show file tree
Hide file tree
Showing 17 changed files with 1,660 additions and 24 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -564,9 +564,9 @@ jobs:
python_version=${{ matrix.python-version }}
wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*win*.whl)
if [ "$python_version" = "3.12" ]; then
pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17"
pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" graph-jsp-env
else
pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna
pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna graph-jsp-env
fi
- name: Test with pytest
Expand Down Expand Up @@ -662,9 +662,9 @@ jobs:
arch=$(uname -m)
wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*macos*${arch}.whl)
if [ "$python_version" = "3.12" ]; then
pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17"
pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" graph-jsp-env
else
pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna
pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna graph-jsp-env
fi
- name: Test with pytest
Expand Down Expand Up @@ -762,9 +762,9 @@ jobs:
python_version=${{ matrix.python-version }}
wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*manylinux*.whl)
if [ "$python_version" = "3.12" ]; then
pip install ${wheelfile}[all] pytest "pygame>=2.5" "cffi>=1.17" docopt commonmark optuna
pip install ${wheelfile}[all] pytest "pygame>=2.5" "cffi>=1.17" docopt commonmark optuna graph-jsp-env
else
pip install ${wheelfile}[all] pytest gymnasium[classic-control] docopt commonmark optuna
pip install ${wheelfile}[all] pytest gymnasium[classic-control] docopt commonmark optuna graph-jsp-env
fi
- name: Test with pytest
Expand Down
140 changes: 140 additions & 0 deletions examples/gnn_sb3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Any

import numpy as np
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
from gymnasium.spaces import Box, Graph, GraphInstance

from skdecide.core import Space, TransitionOutcome, Value
from skdecide.domains import Domain
from skdecide.hub.domain.gym import GymDomain
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.utils import rollout

# JSP graph env


class D(Domain):
T_state = GraphInstance # Type of states
T_observation = T_state # Type of observations
T_event = int # Type of events
T_value = float # Type of transition values (rewards or costs)
T_info = None # Type of additional information in environment outcome


class GraphJspDomain(GymDomain, D):
_gym_env: DisjunctiveGraphJspEnv

def __init__(self, gym_env):
GymDomain.__init__(self, gym_env=gym_env)
if self._gym_env.normalize_observation_space:
self.n_nodes_features = gym_env.n_machines + 1
else:
self.n_nodes_features = 2

def _state_step(
self, action: D.T_event
) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
outcome = super()._state_step(action=action)
outcome.state = self._np_state2graph_state(outcome.state)
return outcome

def _get_applicable_actions_from(
self, memory: D.T_memory[D.T_state]
) -> D.T_agent[Space[D.T_event]]:
return ListSpace(np.nonzero(self._gym_env.valid_action_mask())[0])

def _is_applicable_action_from(
self, action: D.T_agent[D.T_event], memory: D.T_memory[D.T_state]
) -> bool:
return self._gym_env.valid_action_mask()[action]

def _state_reset(self) -> D.T_state:
return self._np_state2graph_state(super()._state_reset())

def _get_observation_space_(self) -> Space[D.T_observation]:
if self._gym_env.normalize_observation_space:
original_graph_space = Graph(
node_space=Box(
low=0.0, high=1.0, shape=(self.n_nodes_features,), dtype=np.float_
),
edge_space=Box(low=0, high=1.0, dtype=np.float_),
)

else:
original_graph_space = Graph(
node_space=Box(
low=np.array([0, 0]),
high=np.array(
[
self._gym_env.n_machines,
self._gym_env.longest_processing_time,
]
),
dtype=np.int_,
),
edge_space=Box(
low=0, high=self._gym_env.longest_processing_time, dtype=np.int_
),
)
return GymSpace(original_graph_space)

def _np_state2graph_state(self, np_state: np.array) -> GraphInstance:
if not self._gym_env.normalize_observation_space:
np_state = np_state.astype(np.int_)

nodes = np_state[:, -self.n_nodes_features :]
adj = np_state[:, : -self.n_nodes_features]
edge_starts_ends = adj.nonzero()
edge_links = np.transpose(edge_starts_ends)
edges = adj[edge_starts_ends][:, None]

return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)

def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
return self._gym_env.render(**kwargs)


jsp = np.array(
[
[
[0, 1, 2], # machines for job 0
[0, 2, 1], # machines for job 1
[0, 1, 2], # machines for job 2
],
[
[3, 2, 2], # task durations of job 0
[2, 1, 4], # task durations of job 1
[0, 4, 3], # task durations of job 2
],
]
)


jsp_env = DisjunctiveGraphJspEnv(
jps_instance=jsp,
perform_left_shift_if_possible=True,
normalize_observation_space=False,
flat_observation_space=False,
action_mode="task",
)


# random rollout
domain = GraphJspDomain(gym_env=jsp_env)
rollout(domain=domain, max_steps=jsp_env.total_tasks_without_dummies, num_episodes=1)

# solve with sb3-PPO-GNN
domain_factory = lambda: GraphJspDomain(gym_env=jsp_env)
with StableBaseline(
domain_factory=domain_factory,
algo_class=GraphPPO,
baselines_policy="GraphInputPolicy",
learn_config={"total_timesteps": 100},
# batch_size=1,
# normalize_advantage=False
) as solver:

solver.solve()
rollout(domain=domain_factory(), solver=solver, max_steps=100, num_episodes=1)
Loading

0 comments on commit 772a739

Please sign in to comment.