Skip to content

Commit

Permalink
Add RDDLDomain, RDDLJaxSolver, and RDDLGurobiSolver to hub as presented in ICAPS 2024 tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed Oct 28, 2024
1 parent 85e5c49 commit 17987d8
Show file tree
Hide file tree
Showing 14 changed files with 2,132 additions and 52 deletions.
604 changes: 604 additions & 0 deletions notebooks/16_rddl_tuto.ipynb

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added notebooks/rddl_images/cgp-sketch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,104 changes: 1,057 additions & 47 deletions poetry.lock

Large diffs are not rendered by default.

27 changes: 22 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scikit-decide"
version = "0.0.0" # placeholder for poetry-dynamic-versioning
version = "1.0.3.dev5+90a1f41c" # placeholder for poetry-dynamic-versioning
description = "The AI framework for Reinforcement Learning, Automated Planning and Scheduling"
authors = ["Airbus AI Research <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -34,7 +34,7 @@ script = "builder.py"
generate-setup-file = true

[tool.poetry-dynamic-versioning]
enable = true
enable = false
vcs = "git"
format-jinja = """
{%- if distance == 0 -%}
Expand Down Expand Up @@ -71,6 +71,10 @@ pygrib = [
{ version = "<=2.1.5", platform = "linux", optional = true },
{ version = ">=2.1.5", platform = "darwin", optional = true },
]
pyRDDLGym = { version = ">=2.0, <2.1", optional = true }
pyRDDLGym-rl = { version = ">=0.1", optional = true }
pyRDDLGym-jax = { version = ">=0.3", optional = true }
rddlrepository = {version = ">=2.0", optional = true }

[tool.poetry.extras]
domains = [
Expand All @@ -82,7 +86,10 @@ domains = [
"unified-planning",
"cartopy",
"pygrib",
"scipy"
"scipy",
"pyRDDLGym",
"pyRDDLGym-rl",
"rddlrepository"
]
solvers = [
"gymnasium",
Expand All @@ -95,7 +102,8 @@ solvers = [
"up-fast-downward",
"up-enhsp",
"up-pyperplan",
"scipy"
"scipy",
"pyRDDLGym-jax"
]
all = [
"gymnasium",
Expand All @@ -113,7 +121,11 @@ all = [
"up-pyperplan",
"cartopy",
"pygrib",
"scipy"
"scipy",
"pyRDDLGym",
"pyRDDLGym-rl",
"rddlrepository",
"pyRDDLGym-jax"
]

[tool.poetry.plugins."skdecide.domains"]
Expand All @@ -136,6 +148,9 @@ Stochastic_RCPSP = "skdecide.hub.domain.rcpsp:Stochastic_RCPSP [domains]"
SMRCPSPCalendar = "skdecide.hub.domain.rcpsp:SMRCPSPCalendar [domains]"
MSRCPSP = "skdecide.hub.domain.rcpsp:MSRCPSP [domains]"
MSRCPSPCalendar = "skdecide.hub.domain.rcpsp:MSRCPSPCalendar [domains]"
RDDLDomain = "skdecide.hub.domain.rddl:RDDLDomain [domains]"
RDDLDomainRL = "skdecide.hub.domain.rddl:RDDLDomainRL [domains]"
RDDLDomainSimplifiedSpaces = "skdecide.hub.domain.rddl:RDDLDomainSimplifiedSpaces [domains]"

[tool.poetry.plugins."skdecide.solvers"]
AOstar = "skdecide.hub.solver.aostar:AOstar [solvers]"
Expand All @@ -162,6 +177,8 @@ DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]"
GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy_scheduling:PilePolicy [solvers]"
UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]"
RDDLJaxSolver = "skdecide.hub.solver.rddl:RDDLJaxSolver [solvers]"
RDDLGurobiSolver = "skdecide.hub.solver.rddl:RDDLGurobiSolver [solvers]"

[tool.poetry.dev-dependencies]
pytest = "^6.2.2"
Expand Down
1 change: 1 addition & 0 deletions skdecide/hub/domain/rddl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rddl import RDDLDomain, RDDLDomainRL, RDDLDomainSimplifiedSpaces
179 changes: 179 additions & 0 deletions skdecide/hub/domain/rddl/rddl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os
import shutil
from datetime import datetime
from typing import Any

import numpy as np
import pyRDDLGym
from gymnasium.spaces.utils import flatten, flatten_space
from pyRDDLGym import RDDLEnv
from pyRDDLGym.core.simulator import RDDLSimulator
from pyRDDLGym.core.visualizer.chart import ChartVisualizer
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym.core.visualizer.viz import BaseViz
from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv

from skdecide.builders.domain import FullyObservable, Renderable, UnrestrictedActions
from skdecide.core import Space, TransitionOutcome, Value
from skdecide.domains import RLDomain
from skdecide.hub.space.gym import GymSpace

try:
import IPython
except ImportError:
ipython_available = False
else:
ipython_available = True
from IPython.display import clear_output, display


class D(RLDomain, UnrestrictedActions, FullyObservable, Renderable):
    """Domain characteristics shared by the RDDL domains below: a fully
    observable RL domain with unrestricted actions and rendering support."""

    T_state = dict[str, Any]  # Type of states
    T_observation = T_state  # Type of observations
    # Fix: np.array is a factory *function*, not a type; the array type is
    # np.ndarray, which is what flattened RDDL actions actually are.
    T_event = np.ndarray  # Type of events
    T_value = float  # Type of transition values (rewards or costs)
    T_info = None  # Type of additional information in environment outcome


class RDDLDomain(D):
    """Scikit-decide RL domain wrapping a pyRDDLGym environment.

    The domain is built from a RDDL domain/instance pair via ``pyRDDLGym.make``
    and can optionally record each episode as an animated GIF movie.

    # Parameters
    - rddl_domain: RDDL domain identifier passed to ``pyRDDLGym.make``.
    - rddl_instance: RDDL instance identifier passed to ``pyRDDLGym.make``.
    - base_class: gym environment class to instantiate (defaults to ``RDDLEnv``).
    - backend: RDDL simulator backend class.
    - display_with_pygame: if True, ``render()`` also displays in a pygame window.
    - display_within_jupyter: if True (and IPython is available), ``render()``
      displays the image inline in the current Jupyter cell.
    - visualizer: pyRDDLGym visualizer class used to produce render images.
    - movie_name: when not None, frames are recorded and saved as a GIF movie
      under ``movie_dir/movie_name`` at episode end.
    - movie_dir: root directory for recorded movies.
    - max_frames: maximum number of frames recorded by the movie generator.
    - enforce_action_constraints: forwarded to ``pyRDDLGym.make``; presumably
      makes the env reject constraint-violating actions — TODO confirm.
    """

    def __init__(
        self,
        rddl_domain: str,
        rddl_instance: str,
        base_class: type[RDDLEnv] = RDDLEnv,
        backend: type[RDDLSimulator] = RDDLSimulator,
        display_with_pygame: bool = True,
        display_within_jupyter: bool = False,
        visualizer: BaseViz = ChartVisualizer,
        movie_name: str = None,
        movie_dir: str = "rddl_movies",
        max_frames=1000,
        enforce_action_constraints=True,
        **kwargs
    ):
        self.rddl_gym_env = pyRDDLGym.make(
            rddl_domain,
            rddl_instance,
            base_class=base_class,
            backend=backend,
            enforce_action_constraints=enforce_action_constraints,
            **kwargs
        )
        self.display_within_jupyter = display_within_jupyter
        self.display_with_pygame = display_with_pygame
        self.movie_name = movie_name
        # Step counter within the current episode; used to stop movie
        # recording once max_frames is reached.
        self._nb_step = 0
        if movie_name is not None:
            self.movie_path = os.path.join(movie_dir, movie_name)
            if not os.path.exists(self.movie_path):
                os.makedirs(self.movie_path)
            # Frames are first dumped as PNGs in a scratch subdirectory,
            # which is wiped at construction so stale frames never leak
            # into a new movie.
            tmp_pngs = os.path.join(self.movie_path, "tmp_pngs")
            if os.path.exists(tmp_pngs):
                shutil.rmtree(tmp_pngs)
            os.makedirs(tmp_pngs)
            self.movie_gen = MovieGenerator(tmp_pngs, movie_name, max_frames=max_frames)
            self.rddl_gym_env.set_visualizer(visualizer, self.movie_gen)
        else:
            self.movie_gen = None
            self.rddl_gym_env.set_visualizer(visualizer)

    def _state_step(
        self, action: D.T_event
    ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
        """Apply one action in the wrapped gym env and return the outcome.

        When movie recording is active, the animation is finalized (and the
        resulting GIF moved to a timestamped file) as soon as the episode
        terminates or the frame budget is exhausted.
        """
        next_state, reward, terminated, truncated, _ = self.rddl_gym_env.step(action)
        termination = terminated or truncated
        if self.movie_gen is not None and (
            termination or self._nb_step >= self.movie_gen.max_frames - 1
        ):
            self.movie_gen.save_animation(self.movie_name)
            tmp_pngs = os.path.join(self.movie_path, "tmp_pngs")
            # Timestamp the GIF so successive episodes don't overwrite each other.
            shutil.move(
                os.path.join(tmp_pngs, self.movie_name + ".gif"),
                os.path.join(
                    self.movie_path,
                    self.movie_name
                    + "_"
                    + str(datetime.now().strftime("%Y%m%d-%H%M%S"))
                    + ".gif",
                ),
            )
        self._nb_step += 1
        # TransitionOutcome and Value are scikit-decide types
        return TransitionOutcome(
            state=next_state, value=Value(reward=reward), termination=termination
        )

    def _get_action_space_(self) -> Space[D.T_event]:
        # Cast to skdecide's GymSpace
        return GymSpace(self.rddl_gym_env.action_space)

    def _state_reset(self) -> D.T_state:
        """Reset the wrapped env and return its initial state."""
        self._nb_step = 0
        # SkDecide only needs the state, not the info
        return self.rddl_gym_env.reset()[0]

    def _get_observation_space_(self) -> Space[D.T_observation]:
        # Cast to skdecide's GymSpace
        return GymSpace(self.rddl_gym_env.observation_space)

    def _render_from(self, memory: D.T_state = None, **kwargs: Any) -> Any:
        """Render the current env state and return the rendered image.

        Depending on the constructor flags, the image may additionally be
        shown in a pygame window and/or displayed inline in a Jupyter cell.
        """
        rddl_gym_img = self.rddl_gym_env.render(to_display=self.display_with_pygame)
        if self.display_within_jupyter and ipython_available:
            clear_output(wait=True)
            display(rddl_gym_img)
        return rddl_gym_img


class RDDLDomainRL(RDDLDomain):
    """Same contract as ``RDDLDomain``, but the default environment class is
    pyRDDLGym-rl's ``SimplifiedActionRDDLEnv``, better suited to RL solvers.
    """

    def __init__(
        self,
        rddl_domain: str,
        rddl_instance: str,
        base_class: type[RDDLEnv] = SimplifiedActionRDDLEnv,
        backend: type[RDDLSimulator] = RDDLSimulator,
        display_with_pygame: bool = True,
        display_within_jupyter: bool = False,
        visualizer: BaseViz = ChartVisualizer,
        movie_name: str = None,
        movie_dir: str = "rddl_movies",
        max_frames=1000,
        enforce_action_constraints=True,
        **kwargs
    ):
        # Only the default value of base_class differs from the parent;
        # everything is forwarded verbatim.
        forwarded = dict(
            rddl_domain=rddl_domain,
            rddl_instance=rddl_instance,
            base_class=base_class,
            backend=backend,
            display_with_pygame=display_with_pygame,
            display_within_jupyter=display_within_jupyter,
            visualizer=visualizer,
            movie_name=movie_name,
            movie_dir=movie_dir,
            max_frames=max_frames,
            enforce_action_constraints=enforce_action_constraints,
        )
        super().__init__(**forwarded, **kwargs)


class RDDLDomainSimplifiedSpaces(RDDLDomainRL):
    """``RDDLDomainRL`` variant whose states and spaces are flattened with
    gymnasium's ``flatten``/``flatten_space`` utilities, so solvers see plain
    array-like observations and actions instead of structured dicts."""

    def _state_step(
        self, action: D.T_event
    ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
        raw = super()._state_step(action)
        flat_state = flatten(self.rddl_gym_env.observation_space, raw.state)
        return TransitionOutcome(
            state=flat_state,
            value=raw.value,
            termination=raw.termination,
        )

    def _get_action_space_(self) -> Space[D.T_event]:
        # Flattened view of the underlying gym action space
        return GymSpace(flatten_space(self.rddl_gym_env.action_space))

    def _state_reset(self) -> D.T_state:
        # SkDecide only needs the state, not the info
        initial = super()._state_reset()
        return flatten(self.rddl_gym_env.observation_space, initial)

    def _get_observation_space_(self) -> Space[D.T_observation]:
        # Flattened view of the underlying gym observation space
        return GymSpace(flatten_space(self.rddl_gym_env.observation_space))
1 change: 1 addition & 0 deletions skdecide/hub/solver/rddl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rddl import RDDLGurobiSolver, RDDLJaxSolver
117 changes: 117 additions & 0 deletions skdecide/hub/solver/rddl/rddl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections.abc import Callable
from typing import Any, Optional

from pyRDDLGym_jax.core.planner import (
JaxBackpropPlanner,
JaxOfflineController,
JaxOnlineController,
load_config,
)
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

from skdecide import Solver
from skdecide.builders.solver import FromInitialState, Policies
from skdecide.hub.domain.rddl import RDDLDomain

try:
from pyRDDLGym_gurobi.core.planner import (
GurobiOnlineController,
GurobiPlan,
GurobiStraightLinePlan,
)
except ImportError:
pyrddlgym_gurobi_available = False
else:
pyrddlgym_gurobi_available = True


# Domain requirement for the solvers below: any RDDLDomain-compatible domain.
class D(RDDLDomain):
    pass


class RDDLJaxSolver(Solver, Policies, FromInitialState):
    """Offline gradient-based planner for RDDL domains, wrapping
    pyRDDLGym-jax's ``JaxBackpropPlanner``/``JaxOfflineController`` as a
    scikit-decide solver.

    # Parameters
    - domain_factory: callable returning the RDDLDomain to solve.
    - config: optional path to a JaxPlan configuration file; when given it is
      parsed with pyRDDLGym-jax's ``load_config`` to obtain planner and
      training arguments.
    """

    T_domain = D

    def __init__(
        self, domain_factory: Callable[[], RDDLDomain], config: Optional[str] = None
    ):
        Solver.__init__(self, domain_factory=domain_factory)
        self._domain = domain_factory()
        if config is not None:
            self.planner_args, _, self.train_args = load_config(config)
        else:
            # Bug fix: these attributes were previously left unset when no
            # config was provided, making _solve() raise AttributeError even
            # though `config` is declared Optional. None means "use the
            # planner/controller defaults" (see the fallbacks in _solve).
            self.planner_args = None
            self.train_args = None

    @classmethod
    def _check_domain_additional(cls, domain: D) -> bool:
        # The solver needs direct access to the wrapped pyRDDLGym environment.
        return hasattr(domain, "rddl_gym_env")

    def _solve(self, from_memory: Optional[D.T_state] = None) -> None:
        """Train an offline controller on the domain's RDDL model."""
        planner = JaxBackpropPlanner(
            rddl=self._domain.rddl_gym_env.model,
            **(self.planner_args if self.planner_args is not None else {})
        )
        self.controller = JaxOfflineController(
            planner, **(self.train_args if self.train_args is not None else {})
        )

    def _sample_action(self, observation: D.T_observation) -> D.T_event:
        return self.controller.sample_action(observation)

    def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
        # The trained controller yields an action for any observation.
        return True


if pyrddlgym_gurobi_available:

    # Note: the previous redundant redefinition of `class D(RDDLDomain)` here
    # shadowed the identical module-level D and has been removed; T_domain now
    # refers to the single module-level D.

    class RDDLGurobiSolver(Solver, Policies, FromInitialState):
        """Online MILP-based controller for RDDL domains, wrapping
        pyRDDLGym-gurobi's ``GurobiOnlineController`` as a scikit-decide solver.

        # Parameters
        - domain_factory: callable returning the RDDLDomain to solve.
        - plan: Gurobi plan template; defaults to ``GurobiStraightLinePlan``.
        - rollout_horizon: lookahead horizon used at each decision step.
        - model_params: Gurobi model parameters; defaults to
          ``{"NonConvex": 2, "OutputFlag": 0}`` (quiet, allow non-convex).
        """

        T_domain = D

        def __init__(
            self,
            domain_factory: Callable[[], RDDLDomain],
            plan: Optional[GurobiPlan] = None,
            rollout_horizon=5,
            model_params: Optional[dict[str, Any]] = None,
        ):
            Solver.__init__(self, domain_factory=domain_factory)
            self._domain = domain_factory()
            self.rollout_horizon = rollout_horizon
            if plan is None:
                self.plan = GurobiStraightLinePlan()
            else:
                self.plan = plan
            if model_params is None:
                self.model_params = {"NonConvex": 2, "OutputFlag": 0}
            else:
                self.model_params = model_params

        @classmethod
        def _check_domain_additional(cls, domain: D) -> bool:
            # The solver needs direct access to the wrapped pyRDDLGym environment.
            return hasattr(domain, "rddl_gym_env")

        def _solve(self, from_memory: Optional[D.T_state] = None) -> None:
            """Build the online controller; planning happens at action time."""
            self.controller = GurobiOnlineController(
                rddl=self._domain.rddl_gym_env.model,
                plan=self.plan,
                rollout_horizon=self.rollout_horizon,
                model_params=self.model_params,
            )

        def _sample_action(self, observation: D.T_observation) -> D.T_event:
            return self.controller.sample_action(observation)

        def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
            # The online controller yields an action for any observation.
            return True

else:

    class RDDLGurobiSolver(Solver, Policies, FromInitialState):
        """Placeholder used when pyRDDLGym-gurobi is not installed: raises a
        helpful error at construction time."""

        T_domain = D

        def __init__(self, domain_factory: Callable[[], RDDLDomain], *args, **kwargs):
            # Accept the real solver's full signature (plan=, model_params=, ...)
            # so callers get this informative RuntimeError instead of a
            # confusing TypeError about unexpected keyword arguments.
            raise RuntimeError(
                "You need pyRDDLGym-gurobi installed for this solver. "
                "See https://github.com/pyrddlgym-project/pyRDDLGym-gurobi for more information."
            )
Loading

0 comments on commit 17987d8

Please sign in to comment.