Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update DOSolver for last d-o release and add callback functionality #353

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions examples/scheduling/computation_time_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,11 @@

from rcpsp_datasets import get_complete_path

from skdecide import DiscreteDistribution, rollout_episode
from skdecide.builders.domain.scheduling.scheduling_domains_modelling import (
SchedulingAction,
SchedulingActionEnum,
State,
)
from skdecide.hub.domain.rcpsp.rcpsp_sk import MSRCPSP, RCPSP
from skdecide.hub.domain.rcpsp.rcpsp_sk_parser import (
load_domain,
load_multiskill_domain,
)
from skdecide import rollout_episode
from skdecide.hub.domain.rcpsp.rcpsp_sk import RCPSP
from skdecide.hub.domain.rcpsp.rcpsp_sk_parser import load_domain
from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod
from skdecide.hub.solver.sgs_policies.sgs_policies import (
from skdecide.hub.solver.do_solver.sgs_policies import (
BasePolicyMethod,
PolicyMethodParams,
)
Expand Down
4 changes: 2 additions & 2 deletions examples/scheduling/gphh_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from skdecide.hub.domain.rcpsp.rcpsp_sk import RCPSP
from skdecide.hub.domain.rcpsp.rcpsp_sk_parser import load_domain
from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod
from skdecide.hub.solver.gphh.gphh import (
from skdecide.hub.solver.do_solver.gphh import (
GPHH,
EvaluationGPHH,
FeatureEnum,
Expand All @@ -29,7 +29,7 @@
min_operator,
protected_div,
)
from skdecide.hub.solver.sgs_policies.sgs_policies import (
from skdecide.hub.solver.do_solver.sgs_policies import (
BasePolicyMethod,
PolicyMethodParams,
)
Expand Down
8 changes: 4 additions & 4 deletions examples/scheduling/policy_sgs_works.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
SolvingMethod,
from_solution_to_policy,
)
from skdecide.hub.solver.do_solver.sgs_policies import (
BasePolicyMethod,
PolicyMethodParams,
)
from skdecide.hub.solver.meta_policy.meta_policies import MetaPolicy
from skdecide.hub.solver.policy_evaluators.policy_evaluator import (
rollout_based_compute_expected_cost_for_policy_scheduling,
)
from skdecide.hub.solver.sgs_policies.sgs_policies import (
BasePolicyMethod,
PolicyMethodParams,
)


# Compare different online policies based on permutations over a few sampled scenarios.
Expand Down
2 changes: 1 addition & 1 deletion examples/scheduling/rcpsp_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
load_multiskill_domain,
)
from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod
from skdecide.hub.solver.sgs_policies.sgs_policies import (
from skdecide.hub.solver.do_solver.sgs_policies import (
BasePolicyMethod,
PolicyMethodParams,
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ RayRLlib = "skdecide.hub.solver.ray_rllib:RayRLlib [solvers]"
SimpleGreedy = "skdecide.hub.solver.simple_greedy:SimpleGreedy [solvers]"
StableBaseline = "skdecide.hub.solver.stable_baselines:StableBaseline [solvers]"
DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]"
GPHH = "skdecide.hub.solver.gphh:GPHH [solvers]"
GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy:PilePolicy [solvers]"
UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]"

Expand Down
12 changes: 12 additions & 0 deletions skdecide/hub/solver/do_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,15 @@
# LICENSE file in the root directory of this source tree.

from .do_solver_scheduling import DOSolver
from .gphh import (
GPHH,
EvaluationGPHH,
FeatureEnum,
FixedPermutationPolicy,
GPHHPolicy,
ParametersGPHH,
PermutationDistance,
PoolAggregationMethod,
PooledGPHHPolicy,
)
from .sgs_policies import BasePolicyMethod, PolicyMethodParams, PolicyRCPSP
139 changes: 78 additions & 61 deletions skdecide/hub/solver/do_solver/do_solver_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel, RCPSPSolution
from discrete_optimization.rcpsp_multiskill.rcpsp_multiskill import (
MS_RCPSPModel,
MS_RCPSPSolution,
MS_RCPSPSolution_Variant,
)

from skdecide import Domain
from skdecide.builders.domain.scheduling.scheduling_domains import SchedulingDomain
from skdecide.hub.solver.do_solver.sgs_policies import PolicyMethodParams, PolicyRCPSP
from skdecide.hub.solver.do_solver.sk_to_do_binding import build_do_domain
from skdecide.hub.solver.sgs_policies.sgs_policies import (
BasePolicyMethod,
PolicyMethodParams,
PolicyRCPSP,
)
from skdecide.solvers import DeterministicPolicies, Solver


Expand All @@ -29,70 +31,46 @@ class D(SchedulingDomain):


class SolvingMethod(Enum):
PILE = 0
GA = 1
LS = 2
LP = 3
CP = 4
LNS_LP = 5
LNS_CP = 6
LNS_CP_CALENDAR = 7
# New algorithm, similar to LNS, iteratively adding constraints to fulfill calendar constraints.


def build_solver(solving_method: SolvingMethod, do_domain):
PILE = "greedy"
GA = "ga"
LS = "ls"
LP = "lp"
CP = "cp"
LNS_LP = "lns-lp"
LNS_CP = "lns-scheduling"


def build_solver(
solving_method: SolvingMethod, do_domain
) -> Tuple[SolverDO, Dict[str, Any]]:
if isinstance(do_domain, RCPSPModel):
from discrete_optimization.rcpsp.rcpsp_solvers import (
look_for_solver,
solvers_map,
)

available = look_for_solver(do_domain)
solving_method_to_str = {
SolvingMethod.PILE: "greedy",
SolvingMethod.GA: "ga",
SolvingMethod.LS: "ls",
SolvingMethod.LP: "lp",
SolvingMethod.CP: "cp",
SolvingMethod.LNS_LP: "lns-lp",
SolvingMethod.LNS_CP: "lns-cp",
SolvingMethod.LNS_CP_CALENDAR: "lns-cp-calendar",
}
smap = [
(av, solvers_map[av])
for av in available
if solvers_map[av][0] == solving_method_to_str[solving_method]
]
if len(smap) > 0:
return smap[0]
if isinstance(do_domain, MS_RCPSPModel):
do_domain_cls = RCPSPModel
elif isinstance(do_domain, MS_RCPSPModel):
from discrete_optimization.rcpsp_multiskill.rcpsp_multiskill_solvers import (
look_for_solver,
solvers_map,
)

available = look_for_solver(do_domain)
solving_method_to_str = {
SolvingMethod.PILE: "greedy",
SolvingMethod.GA: "ga",
SolvingMethod.LS: "ls",
SolvingMethod.LP: "lp",
SolvingMethod.CP: "cp",
SolvingMethod.LNS_LP: "lns-lp",
SolvingMethod.LNS_CP: "lns-cp",
SolvingMethod.LNS_CP_CALENDAR: "lns-cp-calendar",
}

smap = [
(av, solvers_map[av])
for av in available
if solvers_map[av][0] == solving_method_to_str[solving_method]
]

if len(smap) > 0:
return smap[0]

return None
do_domain_cls = MS_RCPSPModel
else:
raise ValueError("do_domain should be either a RCPSPModel or a MS_RCPSPModel.")
available = look_for_solver(do_domain)
smap = [
(av, solvers_map[av])
for av in available
if solvers_map[av][0] == solving_method.value
]
if len(smap) > 0:
return smap[0]
else:
raise ValueError(
f"solving_method {solving_method} not available for {do_domain_cls}."
)


def from_solution_to_policy(
Expand Down Expand Up @@ -158,8 +136,10 @@ def __init__(
self,
policy_method_params: PolicyMethodParams,
method: SolvingMethod = SolvingMethod.PILE,
dict_params: Dict[Any, Any] = None,
dict_params: Optional[Dict[Any, Any]] = None,
callback: Optional[Callable[[Domain, DOSolver], bool]] = None,
):
self.callback = callback
self.method = method
self.policy_method_params = policy_method_params
self.dict_params = dict_params
Expand Down Expand Up @@ -196,12 +176,20 @@ def _solve(self, domain_factory: Callable[[], D]) -> None:
if k not in self.dict_params:
self.dict_params[k] = params[k]

# callbacks
if self.callback is None:
callbacks = []
else:
callbacks = [
_DOCallback(callback=self.callback, domain=self.domain, solver=self)
]

self.solver = solver_class(self.do_domain, **self.dict_params)

if hasattr(self.solver, "init_model") and callable(self.solver.init_model):
self.solver.init_model(**self.dict_params)

result_storage = self.solver.solve(**self.dict_params)
result_storage = self.solver.solve(callbacks=callbacks, **self.dict_params)
best_solution: RCPSPSolution = result_storage.get_best_solution()

assert best_solution is not None
Expand Down Expand Up @@ -233,3 +221,32 @@ def _get_next_action(

def _is_policy_defined_for(self, observation: D.T_agent[D.T_observation]) -> bool:
return self.policy_object.is_policy_defined_for(observation=observation)


class _DOCallback(Callback):
    """Bridge a user-supplied callback to the discrete-optimization callback API.

    The wrapped callable is invoked with the domain and solver stored on this
    object at construction time, not with the arguments passed to the hook.
    """

    def __init__(
        self,
        callback: Callable[[Domain, DOSolver], bool],
        domain: Domain,
        solver: Solver,
    ):
        """Store the user callback together with the objects it will receive.

        Args:
            callback: callable invoked as ``callback(domain, solver)`` at the
                end of each optimization step; its result is returned as-is
                by `on_step_end`.
            domain: domain forwarded to `callback` as first argument.
            solver: solver forwarded to `callback` as second argument.

        """
        self.callback = callback
        self.solver = solver
        self.domain = domain

    def on_step_end(
        self, step: int, res: ResultStorage, solver: SolverDO
    ) -> Optional[bool]:
        """Hook run once an optimization step is finished.

        The hook arguments are not used: the user callback only receives the
        domain and solver stored on this object.

        Args:
            step: index of step
            res: current result storage
            solver: solvers using the callback

        Returns:
            Whatever the user callback returns. If `True`, the optimization
            process is stopped, else it goes on.

        """
        return self.callback(self.domain, self.solver)
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
PolicyRCPSP,
SolvingMethod,
)
from skdecide.hub.solver.do_solver.sgs_policies import BasePolicyMethod
from skdecide.hub.solver.do_solver.sk_to_do_binding import build_do_domain
from skdecide.hub.solver.sgs_policies.sgs_policies import BasePolicyMethod


def if_then_else(input, output1, output2):
Expand Down
12 changes: 0 additions & 12 deletions skdecide/hub/solver/gphh/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
# Copyright (c) AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .gphh import (
GPHH,
EvaluationGPHH,
FeatureEnum,
FixedPermutationPolicy,
GPHHPolicy,
ParametersGPHH,
PermutationDistance,
PoolAggregationMethod,
PooledGPHHPolicy,
)
2 changes: 0 additions & 2 deletions skdecide/hub/solver/sgs_policies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# Copyright (c) AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .sgs_policies import BasePolicyMethod, PolicyMethodParams, PolicyRCPSP
Loading
Loading