-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RDDLDomain, RDDLJaxSolver, and RDDLGurobiSolver to hub as present…
…ed in ICAPS 2024 tutorial
- Loading branch information
Showing
14 changed files
with
2,132 additions
and
52 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "scikit-decide" | ||
version = "0.0.0" # placeholder for poetry-dynamic-versioning | ||
version = "1.0.3.dev5+90a1f41c" # placeholder for poetry-dynamic-versioning | ||
description = "The AI framework for Reinforcement Learning, Automated Planning and Scheduling" | ||
authors = ["Airbus AI Research <[email protected]>"] | ||
license = "MIT" | ||
|
@@ -34,7 +34,7 @@ script = "builder.py" | |
generate-setup-file = true | ||
|
||
[tool.poetry-dynamic-versioning] | ||
enable = true | ||
enable = false | ||
vcs = "git" | ||
format-jinja = """ | ||
{%- if distance == 0 -%} | ||
|
@@ -71,6 +71,10 @@ pygrib = [ | |
{ version = "<=2.1.5", platform = "linux", optional = true }, | ||
{ version = ">=2.1.5", platform = "darwin", optional = true }, | ||
] | ||
pyRDDLGym = { version = ">=2.0, <2.1", optional = true } | ||
pyRDDLGym-rl = { version = ">=0.1", optional = true } | ||
pyRDDLGym-jax = { version = ">=0.3", optional = true } | ||
rddlrepository = {version = ">=2.0", optional = true } | ||
|
||
[tool.poetry.extras] | ||
domains = [ | ||
|
@@ -82,7 +86,10 @@ domains = [ | |
"unified-planning", | ||
"cartopy", | ||
"pygrib", | ||
"scipy" | ||
"scipy", | ||
"pyRDDLGym", | ||
"pyRDDLGym-rl", | ||
"rddlrepository" | ||
] | ||
solvers = [ | ||
"gymnasium", | ||
|
@@ -95,7 +102,8 @@ solvers = [ | |
"up-fast-downward", | ||
"up-enhsp", | ||
"up-pyperplan", | ||
"scipy" | ||
"scipy", | ||
"pyRDDLGym-jax" | ||
] | ||
all = [ | ||
"gymnasium", | ||
|
@@ -113,7 +121,11 @@ all = [ | |
"up-pyperplan", | ||
"cartopy", | ||
"pygrib", | ||
"scipy" | ||
"scipy", | ||
"pyRDDLGym", | ||
"pyRDDLGym-rl", | ||
"rddlrepository", | ||
"pyRDDLGym-jax" | ||
] | ||
|
||
[tool.poetry.plugins."skdecide.domains"] | ||
|
@@ -136,6 +148,9 @@ Stochastic_RCPSP = "skdecide.hub.domain.rcpsp:Stochastic_RCPSP [domains]" | |
SMRCPSPCalendar = "skdecide.hub.domain.rcpsp:SMRCPSPCalendar [domains]" | ||
MSRCPSP = "skdecide.hub.domain.rcpsp:MSRCPSP [domains]" | ||
MSRCPSPCalendar = "skdecide.hub.domain.rcpsp:MSRCPSPCalendar [domains]" | ||
RDDLDomain = "skdecide.hub.domain.rddl:RDDLDomain [domains]" | ||
RDDLDomainRL = "skdecide.hub.domain.rddl:RDDLDomainRL [domains]" | ||
RDDLDomainSimplifiedSpaces = "skdecide.hub.domain.rddl:RDDLDomainSimplifiedSpaces [domains]" | ||
|
||
[tool.poetry.plugins."skdecide.solvers"] | ||
AOstar = "skdecide.hub.solver.aostar:AOstar [solvers]" | ||
|
@@ -162,6 +177,8 @@ DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]" | |
GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]" | ||
PilePolicy = "skdecide.hub.solver.pile_policy_scheduling:PilePolicy [solvers]" | ||
UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]" | ||
RDDLJaxSolver = "skdecide.hub.solver.rddl:RDDLJaxSolver [solvers]" | ||
RDDLGurobiSolver = "skdecide.hub.solver.rddl:RDDLGurobiSolver [solvers]" | ||
|
||
[tool.poetry.dev-dependencies] | ||
pytest = "^6.2.2" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .rddl import RDDLDomain, RDDLDomainRL, RDDLDomainSimplifiedSpaces |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import os | ||
import shutil | ||
from datetime import datetime | ||
from typing import Any | ||
|
||
import numpy as np | ||
import pyRDDLGym | ||
from gymnasium.spaces.utils import flatten, flatten_space | ||
from pyRDDLGym import RDDLEnv | ||
from pyRDDLGym.core.simulator import RDDLSimulator | ||
from pyRDDLGym.core.visualizer.chart import ChartVisualizer | ||
from pyRDDLGym.core.visualizer.movie import MovieGenerator | ||
from pyRDDLGym.core.visualizer.viz import BaseViz | ||
from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv | ||
|
||
from skdecide.builders.domain import FullyObservable, Renderable, UnrestrictedActions | ||
from skdecide.core import Space, TransitionOutcome, Value | ||
from skdecide.domains import RLDomain | ||
from skdecide.hub.space.gym import GymSpace | ||
|
||
try: | ||
import IPython | ||
except ImportError: | ||
ipython_available = False | ||
else: | ||
ipython_available = True | ||
from IPython.display import clear_output, display | ||
|
||
|
||
class D(RLDomain, UnrestrictedActions, FullyObservable, Renderable): | ||
T_state = dict[str, Any] # Type of states | ||
T_observation = T_state # Type of observations | ||
T_event = np.array # Type of events | ||
T_value = float # Type of transition values (rewards or costs) | ||
T_info = None # Type of additional information in environment outcome | ||
|
||
|
||
class RDDLDomain(D): | ||
def __init__( | ||
self, | ||
rddl_domain: str, | ||
rddl_instance: str, | ||
base_class: type[RDDLEnv] = RDDLEnv, | ||
backend: type[RDDLSimulator] = RDDLSimulator, | ||
display_with_pygame: bool = True, | ||
display_within_jupyter: bool = False, | ||
visualizer: BaseViz = ChartVisualizer, | ||
movie_name: str = None, | ||
movie_dir: str = "rddl_movies", | ||
max_frames=1000, | ||
enforce_action_constraints=True, | ||
**kwargs | ||
): | ||
self.rddl_gym_env = pyRDDLGym.make( | ||
rddl_domain, | ||
rddl_instance, | ||
base_class=base_class, | ||
backend=backend, | ||
enforce_action_constraints=enforce_action_constraints, | ||
**kwargs | ||
) | ||
self.display_within_jupyter = display_within_jupyter | ||
self.display_with_pygame = display_with_pygame | ||
self.movie_name = movie_name | ||
self._nb_step = 0 | ||
if movie_name is not None: | ||
self.movie_path = os.path.join(movie_dir, movie_name) | ||
if not os.path.exists(self.movie_path): | ||
os.makedirs(self.movie_path) | ||
tmp_pngs = os.path.join(self.movie_path, "tmp_pngs") | ||
if os.path.exists(tmp_pngs): | ||
shutil.rmtree(tmp_pngs) | ||
os.makedirs(tmp_pngs) | ||
self.movie_gen = MovieGenerator(tmp_pngs, movie_name, max_frames=max_frames) | ||
self.rddl_gym_env.set_visualizer(visualizer, self.movie_gen) | ||
else: | ||
self.movie_gen = None | ||
self.rddl_gym_env.set_visualizer(visualizer) | ||
|
||
def _state_step( | ||
self, action: D.T_event | ||
) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: | ||
next_state, reward, terminated, truncated, _ = self.rddl_gym_env.step(action) | ||
termination = terminated or truncated | ||
if self.movie_gen is not None and ( | ||
termination or self._nb_step >= self.movie_gen.max_frames - 1 | ||
): | ||
self.movie_gen.save_animation(self.movie_name) | ||
tmp_pngs = os.path.join(self.movie_path, "tmp_pngs") | ||
shutil.move( | ||
os.path.join(tmp_pngs, self.movie_name + ".gif"), | ||
os.path.join( | ||
self.movie_path, | ||
self.movie_name | ||
+ "_" | ||
+ str(datetime.now().strftime("%Y%m%d-%H%M%S")) | ||
+ ".gif", | ||
), | ||
) | ||
self._nb_step += 1 | ||
# TransitionOutcome and Value are scikit-decide types | ||
return TransitionOutcome( | ||
state=next_state, value=Value(reward=reward), termination=termination | ||
) | ||
|
||
def _get_action_space_(self) -> Space[D.T_event]: | ||
# Cast to skdecide's GymSpace | ||
return GymSpace(self.rddl_gym_env.action_space) | ||
|
||
def _state_reset(self) -> D.T_state: | ||
self._nb_step = 0 | ||
# SkDecide only needs the state, not the info | ||
return self.rddl_gym_env.reset()[0] | ||
|
||
def _get_observation_space_(self) -> Space[D.T_observation]: | ||
# Cast to skdecide's GymSpace | ||
return GymSpace(self.rddl_gym_env.observation_space) | ||
|
||
def _render_from(self, memory: D.T_state = None, **kwargs: Any) -> Any: | ||
# We do not want the image to be displayed in a pygame window, but rather in this notebook | ||
rddl_gym_img = self.rddl_gym_env.render(to_display=self.display_with_pygame) | ||
if self.display_within_jupyter and ipython_available: | ||
clear_output(wait=True) | ||
display(rddl_gym_img) | ||
return rddl_gym_img | ||
|
||
|
||
class RDDLDomainRL(RDDLDomain): | ||
def __init__( | ||
self, | ||
rddl_domain: str, | ||
rddl_instance: str, | ||
base_class: type[RDDLEnv] = SimplifiedActionRDDLEnv, | ||
backend: type[RDDLSimulator] = RDDLSimulator, | ||
display_with_pygame: bool = True, | ||
display_within_jupyter: bool = False, | ||
visualizer: BaseViz = ChartVisualizer, | ||
movie_name: str = None, | ||
movie_dir: str = "rddl_movies", | ||
max_frames=1000, | ||
enforce_action_constraints=True, | ||
**kwargs | ||
): | ||
super().__init__( | ||
rddl_domain=rddl_domain, | ||
rddl_instance=rddl_instance, | ||
base_class=base_class, | ||
backend=backend, | ||
display_with_pygame=display_with_pygame, | ||
display_within_jupyter=display_within_jupyter, | ||
visualizer=visualizer, | ||
movie_name=movie_name, | ||
movie_dir=movie_dir, | ||
max_frames=max_frames, | ||
enforce_action_constraints=enforce_action_constraints, | ||
**kwargs | ||
) | ||
|
||
|
||
class RDDLDomainSimplifiedSpaces(RDDLDomainRL): | ||
def _state_step( | ||
self, action: D.T_event | ||
) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: | ||
outcome = super()._state_step(action) | ||
return TransitionOutcome( | ||
state=flatten(self.rddl_gym_env.observation_space, outcome.state), | ||
value=outcome.value, | ||
termination=outcome.termination, | ||
) | ||
|
||
def _get_action_space_(self) -> Space[D.T_event]: | ||
return GymSpace(flatten_space(self.rddl_gym_env.action_space)) | ||
|
||
def _state_reset(self) -> D.T_state: | ||
# SkDecide only needs the state, not the info | ||
return flatten(self.rddl_gym_env.observation_space, super()._state_reset()) | ||
|
||
def _get_observation_space_(self) -> Space[D.T_observation]: | ||
return GymSpace(flatten_space(self.rddl_gym_env.observation_space)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .rddl import RDDLGurobiSolver, RDDLJaxSolver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from collections.abc import Callable | ||
from typing import Any, Optional | ||
|
||
from pyRDDLGym_jax.core.planner import ( | ||
JaxBackpropPlanner, | ||
JaxOfflineController, | ||
JaxOnlineController, | ||
load_config, | ||
) | ||
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator | ||
|
||
from skdecide import Solver | ||
from skdecide.builders.solver import FromInitialState, Policies | ||
from skdecide.hub.domain.rddl import RDDLDomain | ||
|
||
try: | ||
from pyRDDLGym_gurobi.core.planner import ( | ||
GurobiOnlineController, | ||
GurobiPlan, | ||
GurobiStraightLinePlan, | ||
) | ||
except ImportError: | ||
pyrddlgym_gurobi_available = False | ||
else: | ||
pyrddlgym_gurobi_available = True | ||
|
||
|
||
class D(RDDLDomain): | ||
pass | ||
|
||
|
||
class RDDLJaxSolver(Solver, Policies, FromInitialState): | ||
T_domain = D | ||
|
||
def __init__( | ||
self, domain_factory: Callable[[], RDDLDomain], config: Optional[str] = None | ||
): | ||
Solver.__init__(self, domain_factory=domain_factory) | ||
self._domain = domain_factory() | ||
if config is not None: | ||
self.planner_args, _, self.train_args = load_config(config) | ||
|
||
@classmethod | ||
def _check_domain_additional(cls, domain: D) -> bool: | ||
return hasattr(domain, "rddl_gym_env") | ||
|
||
def _solve(self, from_memory: Optional[D.T_state] = None) -> None: | ||
planner = JaxBackpropPlanner( | ||
rddl=self._domain.rddl_gym_env.model, | ||
**(self.planner_args if self.planner_args is not None else {}) | ||
) | ||
self.controller = JaxOfflineController( | ||
planner, **(self.train_args if self.train_args is not None else {}) | ||
) | ||
|
||
def _sample_action(self, observation: D.T_observation) -> D.T_event: | ||
return self.controller.sample_action(observation) | ||
|
||
def _is_policy_defined_for(self, observation: D.T_observation) -> bool: | ||
return True | ||
|
||
|
||
if pyrddlgym_gurobi_available: | ||
|
||
class D(RDDLDomain): | ||
pass | ||
|
||
class RDDLGurobiSolver(Solver, Policies, FromInitialState): | ||
T_domain = D | ||
|
||
def __init__( | ||
self, | ||
domain_factory: Callable[[], RDDLDomain], | ||
plan: Optional[GurobiPlan] = None, | ||
rollout_horizon=5, | ||
model_params: Optional[dict[str, Any]] = None, | ||
): | ||
Solver.__init__(self, domain_factory=domain_factory) | ||
self._domain = domain_factory() | ||
self.rollout_horizon = rollout_horizon | ||
if plan is None: | ||
self.plan = GurobiStraightLinePlan() | ||
else: | ||
self.plan = plan | ||
if model_params is None: | ||
self.model_params = {"NonConvex": 2, "OutputFlag": 0} | ||
else: | ||
self.model_params = model_params | ||
|
||
@classmethod | ||
def _check_domain_additional(cls, domain: D) -> bool: | ||
return hasattr(domain, "rddl_gym_env") | ||
|
||
def _solve(self, from_memory: Optional[D.T_state] = None) -> None: | ||
self.controller = GurobiOnlineController( | ||
rddl=self._domain.rddl_gym_env.model, | ||
plan=self.plan, | ||
rollout_horizon=self.rollout_horizon, | ||
model_params=self.model_params, | ||
) | ||
|
||
def _sample_action(self, observation: D.T_observation) -> D.T_event: | ||
return self.controller.sample_action(observation) | ||
|
||
def _is_policy_defined_for(self, observation: D.T_observation) -> bool: | ||
return True | ||
|
||
else: | ||
|
||
class RDDLGurobiSolver(Solver, Policies, FromInitialState): | ||
T_domain = D | ||
|
||
def __init__(self, domain_factory: Callable[[], RDDLDomain], rollout_horizon=5): | ||
raise RuntimeError( | ||
"You need pyRDDLGym-gurobi installed for this solver. " | ||
"See https://github.com/pyrddlgym-project/pyRDDLGym-gurobi for more information." | ||
) |
Oops, something went wrong.