diff --git a/examples/scheduling/computation_time_experiments.py b/examples/scheduling/computation_time_experiments.py index 1e3c501ac8..8117d8eac7 100644 --- a/examples/scheduling/computation_time_experiments.py +++ b/examples/scheduling/computation_time_experiments.py @@ -2,19 +2,11 @@ from rcpsp_datasets import get_complete_path -from skdecide import DiscreteDistribution, rollout_episode -from skdecide.builders.domain.scheduling.scheduling_domains_modelling import ( - SchedulingAction, - SchedulingActionEnum, - State, -) -from skdecide.hub.domain.rcpsp.rcpsp_sk import MSRCPSP, RCPSP -from skdecide.hub.domain.rcpsp.rcpsp_sk_parser import ( - load_domain, - load_multiskill_domain, -) +from skdecide import rollout_episode +from skdecide.hub.domain.rcpsp.rcpsp_sk import RCPSP +from skdecide.hub.domain.rcpsp.rcpsp_sk_parser import load_domain from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod -from skdecide.hub.solver.sgs_policies.sgs_policies import ( +from skdecide.hub.solver.do_solver.sgs_policies import ( BasePolicyMethod, PolicyMethodParams, ) diff --git a/examples/scheduling/gphh_example.py b/examples/scheduling/gphh_example.py index 7c8eb15c9d..4d2191eebc 100644 --- a/examples/scheduling/gphh_example.py +++ b/examples/scheduling/gphh_example.py @@ -10,7 +10,7 @@ from skdecide.hub.domain.rcpsp.rcpsp_sk import RCPSP from skdecide.hub.domain.rcpsp.rcpsp_sk_parser import load_domain from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod -from skdecide.hub.solver.gphh.gphh import ( +from skdecide.hub.solver.do_solver.gphh import ( GPHH, EvaluationGPHH, FeatureEnum, @@ -29,7 +29,7 @@ min_operator, protected_div, ) -from skdecide.hub.solver.sgs_policies.sgs_policies import ( +from skdecide.hub.solver.do_solver.sgs_policies import ( BasePolicyMethod, PolicyMethodParams, ) diff --git a/examples/scheduling/policy_sgs_works.py b/examples/scheduling/policy_sgs_works.py index a19423481c..640ac58911 100644 --- a/examples/scheduling/policy_sgs_works.py +++ b/examples/scheduling/policy_sgs_works.py @@ -10,14 +10,14 @@ SolvingMethod, from_solution_to_policy, ) +from skdecide.hub.solver.do_solver.sgs_policies import ( + BasePolicyMethod, + PolicyMethodParams, +) from skdecide.hub.solver.meta_policy.meta_policies import MetaPolicy from skdecide.hub.solver.policy_evaluators.policy_evaluator import ( rollout_based_compute_expected_cost_for_policy_scheduling, ) -from skdecide.hub.solver.sgs_policies.sgs_policies import ( - BasePolicyMethod, - PolicyMethodParams, -) # Compare different online policies based on permutation on few sampled scenarios. diff --git a/examples/scheduling/rcpsp_examples.py b/examples/scheduling/rcpsp_examples.py index e7395ac7fb..6a4ee05c5e 100644 --- a/examples/scheduling/rcpsp_examples.py +++ b/examples/scheduling/rcpsp_examples.py @@ -7,7 +7,7 @@ load_multiskill_domain, ) from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod -from skdecide.hub.solver.sgs_policies.sgs_policies import ( +from skdecide.hub.solver.do_solver.sgs_policies import ( BasePolicyMethod, PolicyMethodParams, ) diff --git a/pyproject.toml b/pyproject.toml index 2e75f4d99a..901634ea63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,7 +161,7 @@ RayRLlib = "skdecide.hub.solver.ray_rllib:RayRLlib [solvers]" SimpleGreedy = "skdecide.hub.solver.simple_greedy:SimpleGreedy [solvers]" StableBaseline = "skdecide.hub.solver.stable_baselines:StableBaseline [solvers]" DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]" -GPHH = "skdecide.hub.solver.gphh:GPHH [solvers]" +GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]" PilePolicy = "skdecide.hub.solver.pile_policy:PilePolicy [solvers]" UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]" diff --git a/skdecide/hub/solver/do_solver/__init__.py b/skdecide/hub/solver/do_solver/__init__.py index 5b16c66aa7..77b45ae6cc 100644 --- a/skdecide/hub/solver/do_solver/__init__.py +++ b/skdecide/hub/solver/do_solver/__init__.py @@ -3,3 +3,15 @@ # LICENSE file in the root directory of this source tree. from .do_solver_scheduling import DOSolver +from .gphh import ( + GPHH, + EvaluationGPHH, + FeatureEnum, + FixedPermutationPolicy, + GPHHPolicy, + ParametersGPHH, + PermutationDistance, + PoolAggregationMethod, + PooledGPHHPolicy, +) +from .sgs_policies import BasePolicyMethod, PolicyMethodParams, PolicyRCPSP diff --git a/skdecide/hub/solver/do_solver/do_solver_scheduling.py b/skdecide/hub/solver/do_solver/do_solver_scheduling.py index 885008b788..311434db33 100644 --- a/skdecide/hub/solver/do_solver/do_solver_scheduling.py +++ b/skdecide/hub/solver/do_solver/do_solver_scheduling.py @@ -5,8 +5,13 @@ from __future__ import annotations from enum import Enum -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union +from discrete_optimization.generic_tools.callbacks.callback import Callback +from discrete_optimization.generic_tools.do_solver import SolverDO +from discrete_optimization.generic_tools.result_storage.result_storage import ( + ResultStorage, +) from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel, RCPSPSolution from discrete_optimization.rcpsp_multiskill.rcpsp_multiskill import ( MS_RCPSPModel, @@ -14,13 +19,10 @@ MS_RCPSPSolution_Variant, ) +from skdecide import Domain from skdecide.builders.domain.scheduling.scheduling_domains import SchedulingDomain +from skdecide.hub.solver.do_solver.sgs_policies import PolicyMethodParams, PolicyRCPSP from skdecide.hub.solver.do_solver.sk_to_do_binding import build_do_domain -from skdecide.hub.solver.sgs_policies.sgs_policies import ( - BasePolicyMethod, - PolicyMethodParams, - PolicyRCPSP, -) from skdecide.solvers import DeterministicPolicies, Solver @@ -29,70 +31,46 @@ class D(SchedulingDomain): class SolvingMethod(Enum): - PILE = 0 - GA = 1 - LS = 2 - LP = 3 - CP = 4 - LNS_LP = 5 - LNS_CP = 6 - LNS_CP_CALENDAR = 7 - # New algorithm, similar to lns, adding iterativelyu constraint to fulfill calendar constraints.. - - -def build_solver(solving_method: SolvingMethod, do_domain): + PILE = "greedy" + GA = "ga" + LS = "ls" + LP = "lp" + CP = "cp" + LNS_LP = "lns-lp" + LNS_CP = "lns-scheduling" + + +def build_solver( + solving_method: SolvingMethod, do_domain +) -> Tuple[SolverDO, Dict[str, Any]]: if isinstance(do_domain, RCPSPModel): from discrete_optimization.rcpsp.rcpsp_solvers import ( look_for_solver, solvers_map, ) - available = look_for_solver(do_domain) - solving_method_to_str = { - SolvingMethod.PILE: "greedy", - SolvingMethod.GA: "ga", - SolvingMethod.LS: "ls", - SolvingMethod.LP: "lp", - SolvingMethod.CP: "cp", - SolvingMethod.LNS_LP: "lns-lp", - SolvingMethod.LNS_CP: "lns-cp", - SolvingMethod.LNS_CP_CALENDAR: "lns-cp-calendar", - } - smap = [ - (av, solvers_map[av]) - for av in available - if solvers_map[av][0] == solving_method_to_str[solving_method] - ] - if len(smap) > 0: - return smap[0] - if isinstance(do_domain, MS_RCPSPModel): + do_domain_cls = RCPSPModel + elif isinstance(do_domain, MS_RCPSPModel): from discrete_optimization.rcpsp_multiskill.rcpsp_multiskill_solvers import ( look_for_solver, solvers_map, ) - available = look_for_solver(do_domain) - solving_method_to_str = { - SolvingMethod.PILE: "greedy", - SolvingMethod.GA: "ga", - SolvingMethod.LS: "ls", - SolvingMethod.LP: "lp", - SolvingMethod.CP: "cp", - SolvingMethod.LNS_LP: "lns-lp", - SolvingMethod.LNS_CP: "lns-cp", - SolvingMethod.LNS_CP_CALENDAR: "lns-cp-calendar", - } - - smap = [ - (av, solvers_map[av]) - for av in available - if solvers_map[av][0] == solving_method_to_str[solving_method] - ] - - if len(smap) > 0: - return smap[0] - - return None + do_domain_cls = MS_RCPSPModel + else: + raise ValueError("do_domain should be either a RCPSPModel or a MS_RCPSPModel.") + available = look_for_solver(do_domain) + smap = [ + (av, solvers_map[av]) + for av in available + if solvers_map[av][0] == solving_method.value + ] + if len(smap) > 0: + return smap[0] + else: + raise ValueError( + f"solving_method {solving_method} not available for {do_domain_cls}." + ) def from_solution_to_policy( @@ -158,8 +136,10 @@ def __init__( self, policy_method_params: PolicyMethodParams, method: SolvingMethod = SolvingMethod.PILE, - dict_params: Dict[Any, Any] = None, + dict_params: Optional[Dict[Any, Any]] = None, + callback: Optional[Callable[[Domain, DOSolver], bool]] = None, ): + self.callback = callback self.method = method self.policy_method_params = policy_method_params self.dict_params = dict_params @@ -196,12 +176,20 @@ def _solve(self, domain_factory: Callable[[], D]) -> None: if k not in self.dict_params: self.dict_params[k] = params[k] + # callbacks + if self.callback is None: + callbacks = [] + else: + callbacks = [ + _DOCallback(callback=self.callback, domain=self.domain, solver=self) + ] + self.solver = solver_class(self.do_domain, **self.dict_params) if hasattr(self.solver, "init_model") and callable(self.solver.init_model): self.solver.init_model(**self.dict_params) - result_storage = self.solver.solve(**self.dict_params) + result_storage = self.solver.solve(callbacks=callbacks, **self.dict_params) best_solution: RCPSPSolution = result_storage.get_best_solution() assert best_solution is not None @@ -233,3 +221,32 @@ def _get_next_action( def _is_policy_defined_for(self, observation: D.T_agent[D.T_observation]) -> bool: return self.policy_object.is_policy_defined_for(observation=observation) + + +class _DOCallback(Callback): + def __init__( + self, + callback: Callable[[Domain, DOSolver], bool], + domain: Domain, + solver: Solver, + ): + self.domain = domain + self.solver = solver + self.callback = callback + + def on_step_end( + self, step: int, res: ResultStorage, solver: SolverDO + ) -> Optional[bool]: + """Called at the end of an optimization step. + + Args: + step: index of step + res: current result storage + solver: solvers using the callback + + Returns: + If `True`, the optimization process is stopped, else it goes on. + + """ + stopping = self.callback(self.domain, self.solver) + return stopping diff --git a/skdecide/hub/solver/gphh/gphh.py b/skdecide/hub/solver/do_solver/gphh.py similarity index 99% rename from skdecide/hub/solver/gphh/gphh.py rename to skdecide/hub/solver/do_solver/gphh.py index 4cd7e9defc..4cc1425c17 100644 --- a/skdecide/hub/solver/gphh/gphh.py +++ b/skdecide/hub/solver/do_solver/gphh.py @@ -34,8 +34,8 @@ PolicyRCPSP, SolvingMethod, ) +from skdecide.hub.solver.do_solver.sgs_policies import BasePolicyMethod from skdecide.hub.solver.do_solver.sk_to_do_binding import build_do_domain -from skdecide.hub.solver.sgs_policies.sgs_policies import BasePolicyMethod def if_then_else(input, output1, output2): diff --git a/skdecide/hub/solver/sgs_policies/sgs_policies.py b/skdecide/hub/solver/do_solver/sgs_policies.py similarity index 100% rename from skdecide/hub/solver/sgs_policies/sgs_policies.py rename to skdecide/hub/solver/do_solver/sgs_policies.py diff --git a/skdecide/hub/solver/gphh/__init__.py b/skdecide/hub/solver/gphh/__init__.py index f819191a27..91d6e6714d 100644 --- a/skdecide/hub/solver/gphh/__init__.py +++ b/skdecide/hub/solver/gphh/__init__.py @@ -1,15 +1,3 @@ # Copyright (c) AIRBUS and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .gphh import ( - GPHH, - EvaluationGPHH, - FeatureEnum, - FixedPermutationPolicy, - GPHHPolicy, - ParametersGPHH, - PermutationDistance, - PoolAggregationMethod, - PooledGPHHPolicy, -) diff --git a/skdecide/hub/solver/sgs_policies/__init__.py b/skdecide/hub/solver/sgs_policies/__init__.py index e3064e2255..91d6e6714d 100644 --- a/skdecide/hub/solver/sgs_policies/__init__.py +++ b/skdecide/hub/solver/sgs_policies/__init__.py @@ -1,5 +1,3 @@ # Copyright (c) AIRBUS and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -from .sgs_policies import BasePolicyMethod, PolicyMethodParams, PolicyRCPSP diff --git a/tests/scheduling/test_scheduling.py b/tests/scheduling/test_scheduling.py index da8a4a3b4f..289572f14e 100644 --- a/tests/scheduling/test_scheduling.py +++ b/tests/scheduling/test_scheduling.py @@ -1,3 +1,4 @@ +import logging import random from enum import Enum from typing import Any, Dict, List, Optional, Set, Union @@ -35,21 +36,18 @@ rebuild_tasks_modes_dict, ) from skdecide.builders.domain.scheduling.task_duration import DeterministicTaskDuration -from skdecide.hub.domain.rcpsp.rcpsp_sk import ( - MRCPSP, - RCPSP, - build_n_determinist_from_stochastic, -) -from skdecide.hub.solver.do_solver.do_solver_scheduling import ( +from skdecide.hub.domain.rcpsp.rcpsp_sk import build_n_determinist_from_stochastic +from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod +from skdecide.hub.solver.do_solver.gphh import GPHH, ParametersGPHH +from skdecide.hub.solver.do_solver.sgs_policies import ( BasePolicyMethod, - DOSolver, PolicyMethodParams, - SolvingMethod, ) -from skdecide.hub.solver.gphh.gphh import GPHH, ParametersGPHH from skdecide.hub.solver.graph_explorer.DFS_Uncertain_Exploration import DFSExploration from skdecide.hub.solver.lazy_astar import LazyAstar +logger = logging.getLogger(__name__) + optimal_solutions = { "ToyRCPSPDomain": {"makespan": 10}, "ToyMS_RCPSPDomain": {"makespan": 10}, @@ -944,3 +942,60 @@ def test_sgs_policies(domain): ) print("Cost :", sum([v.cost for v in values])) check_rollout_consistency(domain, states) + + +class MyCallback: + """Callback for testing. + + - displays iteration number + - stops after max iteration reached + - check classes of domain and solver + + """ + + def __init__(self, max_iter=2): + self.max_iter = max_iter + self.iter = 0 + + def __call__(self, domain, solver): + self.iter += 1 + logger.warning(f"End of iteration #{self.iter}.") + assert isinstance(domain, ToyRCPSPDomain) + assert isinstance(solver, DOSolver) + stopping = self.iter >= self.max_iter + return stopping + + +def test_do_with_cb(caplog): + domain = ToyRCPSPDomain() + domain.set_inplace_environment(False) + state = domain.get_initial_state() + print("Initial state : ", state) + solver = DOSolver( + policy_method_params=PolicyMethodParams( + base_policy_method=BasePolicyMethod.SGS_PRECEDENCE, + delta_index_freedom=0, + delta_time_freedom=0, + ), + method=SolvingMethod.LNS_CP, + callback=MyCallback(), + # dict_params={"cp_solver_name": CPSolverName.GECODE} + ) + solver.solve(domain_factory=lambda: domain) + + # Check that 2 iterations were done and messages logged by callback + assert "End of iteration #2" in caplog.text + assert "End of iteration #3" not in caplog.text + + # action_formatter=lambda o: str(o), + # outcome_formatter=lambda o: f'{o.observation} - cost: {o.value.cost:.2f}') + states, actions, values = rollout_episode( + domain=domain, + max_steps=1000, + solver=solver, + from_memory=state, + action_formatter=None, + outcome_formatter=None, + verbose=False, + ) + check_rollout_consistency(domain, states)