Skip to content

Commit

Permalink
Add callback to DOSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed May 10, 2024
1 parent d70f565 commit 848aee8
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
50 changes: 47 additions & 3 deletions skdecide/hub/solver/do_solver/do_solver_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, Dict, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel, RCPSPSolution
from discrete_optimization.rcpsp_multiskill.rcpsp_multiskill import (
MS_RCPSPModel,
MS_RCPSPSolution,
MS_RCPSPSolution_Variant,
)

from skdecide import Domain
from skdecide.builders.domain.scheduling.scheduling_domains import SchedulingDomain
from skdecide.hub.solver.do_solver.sgs_policies import PolicyMethodParams, PolicyRCPSP
from skdecide.hub.solver.do_solver.sk_to_do_binding import build_do_domain
Expand Down Expand Up @@ -131,8 +136,10 @@ def __init__(
self,
policy_method_params: PolicyMethodParams,
method: SolvingMethod = SolvingMethod.PILE,
dict_params: Dict[Any, Any] = None,
dict_params: Optional[Dict[Any, Any]] = None,
callback: Optional[Callable[[Domain, DOSolver], bool]] = None,
):
self.callback = callback
self.method = method
self.policy_method_params = policy_method_params
self.dict_params = dict_params
Expand Down Expand Up @@ -169,12 +176,20 @@ def _solve(self, domain_factory: Callable[[], D]) -> None:
if k not in self.dict_params:
self.dict_params[k] = params[k]

# callbacks
if self.callback is None:
callbacks = []
else:
callbacks = [
_DOCallback(callback=self.callback, domain=self.domain, solver=self)
]

self.solver = solver_class(self.do_domain, **self.dict_params)

if hasattr(self.solver, "init_model") and callable(self.solver.init_model):
self.solver.init_model(**self.dict_params)

result_storage = self.solver.solve(**self.dict_params)
result_storage = self.solver.solve(callbacks=callbacks, **self.dict_params)
best_solution: RCPSPSolution = result_storage.get_best_solution()

assert best_solution is not None
Expand Down Expand Up @@ -206,3 +221,32 @@ def _get_next_action(

def _is_policy_defined_for(self, observation: D.T_agent[D.T_observation]) -> bool:
return self.policy_object.is_policy_defined_for(observation=observation)


class _DOCallback(Callback):
def __init__(
self,
callback: Callable[[Domain, DOSolver], bool],
domain: Domain,
solver: Solver,
):
self.domain = domain
self.solver = solver
self.callback = callback

def on_step_end(
self, step: int, res: ResultStorage, solver: SolverDO
) -> Optional[bool]:
"""Called at the end of an optimization step.
Args:
step: index of step
res: current result storage
solver: solvers using the callback
Returns:
If `True`, the optimization process is stopped, else it goes on.
"""
stopping = self.callback(self.domain, self.solver)
return stopping
67 changes: 63 additions & 4 deletions tests/scheduling/test_scheduling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import random
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
Expand Down Expand Up @@ -36,16 +37,17 @@
)
from skdecide.builders.domain.scheduling.task_duration import DeterministicTaskDuration
from skdecide.hub.domain.rcpsp.rcpsp_sk import build_n_determinist_from_stochastic
from skdecide.hub.solver.do_solver.do_solver_scheduling import (
from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod
from skdecide.hub.solver.do_solver.gphh import GPHH, ParametersGPHH
from skdecide.hub.solver.do_solver.sgs_policies import (
BasePolicyMethod,
DOSolver,
PolicyMethodParams,
SolvingMethod,
)
from skdecide.hub.solver.do_solver.gphh import GPHH, ParametersGPHH
from skdecide.hub.solver.graph_explorer.DFS_Uncertain_Exploration import DFSExploration
from skdecide.hub.solver.lazy_astar import LazyAstar

logger = logging.getLogger(__name__)

optimal_solutions = {
"ToyRCPSPDomain": {"makespan": 10},
"ToyMS_RCPSPDomain": {"makespan": 10},
Expand Down Expand Up @@ -940,3 +942,60 @@ def test_sgs_policies(domain):
)
print("Cost :", sum([v.cost for v in values]))
check_rollout_consistency(domain, states)


class MyCallback:
"""Callback for testing.
- displays iteration number
- stops after max iteration reached
- check classes of domain and solver
"""

def __init__(self, max_iter=2):
self.max_iter = max_iter
self.iter = 0

def __call__(self, domain, solver):
self.iter += 1
logger.warning(f"End of iteration #{self.iter}.")
assert isinstance(domain, ToyRCPSPDomain)
assert isinstance(solver, DOSolver)
stopping = self.iter >= self.max_iter
return stopping


def test_do_with_cb(caplog):
domain = ToyRCPSPDomain()
domain.set_inplace_environment(False)
state = domain.get_initial_state()
print("Initial state : ", state)
solver = DOSolver(
policy_method_params=PolicyMethodParams(
base_policy_method=BasePolicyMethod.SGS_PRECEDENCE,
delta_index_freedom=0,
delta_time_freedom=0,
),
method=SolvingMethod.LNS_CP,
callback=MyCallback(),
# dict_params={"cp_solver_name": CPSolverName.GECODE}
)
solver.solve(domain_factory=lambda: domain)

# Check that 2 iterations were done and messages logged by callback
assert "End of iteration #2" in caplog.text
assert "End of iteration #3" not in caplog.text

# action_formatter=lambda o: str(o),
# outcome_formatter=lambda o: f'{o.observation} - cost: {o.value.cost:.2f}')
states, actions, values = rollout_episode(
domain=domain,
max_steps=1000,
solver=solver,
from_memory=state,
action_formatter=None,
outcome_formatter=None,
verbose=False,
)
check_rollout_consistency(domain, states)

0 comments on commit 848aee8

Please sign in to comment.