Skip to content

Commit

Permalink
Merge pull request #6038 from MetRonnie/run-mode-db
Browse files Browse the repository at this point in the history
Refactor run mode restart check to avoid unclosed DB connection
  • Loading branch information
wxtim authored Mar 27, 2024
2 parents ce25a81 + db215d3 commit 4feb7bb
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 69 deletions.
20 changes: 5 additions & 15 deletions cylc/flow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
WorkflowFiles,
check_deprecation,
)
from cylc.flow.workflow_status import RunMode
from cylc.flow.xtrigger_mgr import XtriggerManager

if TYPE_CHECKING:
Expand Down Expand Up @@ -520,7 +521,7 @@ def __init__(

self.process_runahead_limit()

if self.run_mode('simulation', 'dummy'):
if self.run_mode() in {'simulation', 'dummy'}:
self.configure_sim_modes()

self.configure_workflow_state_polling_tasks()
Expand Down Expand Up @@ -1547,20 +1548,9 @@ def process_config_env(self):
os.environ['PATH'] = os.pathsep.join([
os.path.join(self.fdir, 'bin'), os.environ['PATH']])

def run_mode(self, *reqmodes):
"""Return the run mode.
Combine command line option with configuration setting.
If "reqmodes" is specified, return the boolean (mode in reqmodes).
Otherwise, return the mode as a str.
"""
mode = getattr(self.options, 'run_mode', None)
if not mode:
mode = 'live'
if reqmodes:
return mode in reqmodes
else:
return mode
def run_mode(self) -> str:
"""Return the run mode."""
return RunMode.get(self.options)

def _check_task_event_handlers(self):
"""Check custom event handler templates can be expanded.
Expand Down
13 changes: 0 additions & 13 deletions cylc/flow/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,19 +608,6 @@ def select_workflow_params_restart_count(self):
result = self.connect().execute(stmt).fetchone()
return int(result[0]) if result else 0

def select_workflow_params_run_mode(self):
"""Return original run_mode for workflow_params."""
stmt = rf"""
SELECT
value
FROM
{self.TABLE_WORKFLOW_PARAMS}
WHERE
key == 'run_mode'
""" # nosec (table name is code constant)
result = self.connect().execute(stmt).fetchone()
return result[0] if result else None

def select_workflow_template_vars(self, callback):
"""Select from workflow_template_vars.
Expand Down
51 changes: 21 additions & 30 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
from cylc.flow.templatevars import eval_var
from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager
from cylc.flow.workflow_events import WorkflowEventHandler
from cylc.flow.workflow_status import StopMode, AutoRestartMode
from cylc.flow.workflow_status import RunMode, StopMode, AutoRestartMode
from cylc.flow import workflow_files
from cylc.flow.taskdef import TaskDef
from cylc.flow.task_events_mgr import TaskEventsManager
Expand Down Expand Up @@ -425,7 +425,14 @@ async def configure(self, params):
self._check_startup_opts()

if self.is_restart:
run_mode = self.get_run_mode()
self._set_workflow_params(params)
# Prevent changing run mode on restart:
og_run_mode = self.get_run_mode()
if run_mode != og_run_mode:
raise InputError(
f'This workflow was originally run in {og_run_mode} mode:'
f' Will not restart in {run_mode} mode.')

self.profiler.log_memory("scheduler.py: before load_flow_file")
try:
Expand All @@ -435,18 +442,6 @@ async def configure(self, params):
# Mark this exc as expected (see docstring for .schd_expected):
exc.schd_expected = True
raise exc

# Prevent changing mode on restart.
if self.is_restart:
# check run mode against db
og_run_mode = self.workflow_db_mgr.get_pri_dao(
).select_workflow_params_run_mode() or 'live'
run_mode = self.config.run_mode()
if run_mode != og_run_mode:
raise InputError(
f'This workflow was originally run in {og_run_mode} mode:'
f' Will not restart in {run_mode} mode.')

self.profiler.log_memory("scheduler.py: after load_flow_file")

self.workflow_db_mgr.on_workflow_start(self.is_restart)
Expand Down Expand Up @@ -603,7 +598,7 @@ def log_start(self) -> None:
# Note that the following lines must be present at the top of
# the workflow log file for use in reference test runs.
LOG.info(
f'Run mode: {self.config.run_mode()}',
f'Run mode: {self.get_run_mode()}',
extra=RotatingLogFileHandler.header_extra
)
LOG.info(
Expand Down Expand Up @@ -1041,7 +1036,7 @@ def command_resume(self) -> None:

def command_poll_tasks(self, items: List[str]) -> int:
"""Poll pollable tasks or a task or family if options are provided."""
if self.config.run_mode('simulation'):
if self.get_run_mode() == RunMode.SIMULATION:
return 0
itasks, _, bad_items = self.pool.filter_task_proxies(items)
self.task_job_mgr.poll_task_jobs(self.workflow, itasks)
Expand All @@ -1050,7 +1045,7 @@ def command_poll_tasks(self, items: List[str]) -> int:
def command_kill_tasks(self, items: List[str]) -> int:
"""Kill all tasks or a task/family if options are provided."""
itasks, _, bad_items = self.pool.filter_task_proxies(items)
if self.config.run_mode('simulation'):
if self.get_run_mode() == RunMode.SIMULATION:
for itask in itasks:
if itask.state(*TASK_STATUSES_ACTIVE):
itask.state_reset(TASK_STATUS_FAILED)
Expand Down Expand Up @@ -1348,6 +1343,9 @@ def _set_workflow_params(
"""
LOG.info('LOADING workflow parameters')
for key, value in params:
if key == self.workflow_db_mgr.KEY_RUN_MODE:
self.options.run_mode = value or RunMode.LIVE
LOG.info(f"+ run mode = {value}")
if value is None:
continue
if key in self.workflow_db_mgr.KEY_INITIAL_CYCLE_POINT_COMPATS:
Expand All @@ -1368,12 +1366,6 @@ def _set_workflow_params(
elif self.options.stopcp is None:
self.options.stopcp = value
LOG.info(f"+ stop point = {value}")
elif (
key == self.workflow_db_mgr.KEY_RUN_MODE
and self.options.run_mode is None
):
self.options.run_mode = value
LOG.info(f"+ run mode = {value}")
elif key == self.workflow_db_mgr.KEY_UUID_STR:
self.uuid_str = value
LOG.info(f"+ workflow UUID = {value}")
Expand Down Expand Up @@ -1419,12 +1411,8 @@ def run_event_handlers(self, event, reason=""):
Run workflow events in simulation and dummy mode ONLY if enabled.
"""
conf = self.config
with suppress(KeyError):
if (
conf.run_mode('simulation', 'dummy')
):
return
if self.get_run_mode() in {RunMode.SIMULATION, RunMode.DUMMY}:
return
self.workflow_event_handler.handle(self, event, str(reason))

def release_queued_tasks(self) -> bool:
Expand Down Expand Up @@ -1497,7 +1485,7 @@ def release_queued_tasks(self) -> bool:
pre_prep_tasks,
self.server.curve_auth,
self.server.client_pub_key_dir,
is_simulation=self.config.run_mode('simulation')
is_simulation=(self.get_run_mode() == RunMode.SIMULATION)
):
if itask.flow_nums:
flow = ','.join(str(i) for i in itask.flow_nums)
Expand Down Expand Up @@ -1548,7 +1536,7 @@ def timeout_check(self):
"""Check workflow and task timers."""
self.check_workflow_timers()
# check submission and execution timeout and polling timers
if not self.config.run_mode('simulation'):
if self.get_run_mode() != RunMode.SIMULATION:
self.task_job_mgr.check_task_jobs(self.workflow, self.pool)

async def workflow_shutdown(self):
Expand Down Expand Up @@ -2207,6 +2195,9 @@ def _check_startup_opts(self) -> None:
f"option --{opt}=reload is only valid for restart"
)

def get_run_mode(self) -> str:
return RunMode.get(self.options)

async def handle_exception(self, exc: BaseException) -> NoReturn:
"""Gracefully shut down the scheduler given a caught exception.
Expand Down
3 changes: 2 additions & 1 deletion cylc/flow/scheduler_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
is_terminal,
prompt,
)
from cylc.flow.workflow_status import RunMode

if TYPE_CHECKING:
from optparse import Values
Expand Down Expand Up @@ -130,7 +131,7 @@
["-m", "--mode"],
help="Run mode: live, dummy, simulation (default live).",
metavar="STRING", action='store', dest="run_mode",
choices=['live', 'dummy', 'simulation'],
choices=[RunMode.LIVE, RunMode.DUMMY, RunMode.SIMULATION],
)

PLAY_RUN_MODE = deepcopy(RUN_MODE)
Expand Down
3 changes: 2 additions & 1 deletion cylc/flow/scripts/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from cylc.flow.templatevars import get_template_vars
from cylc.flow.terminal import cli_function
from cylc.flow.scheduler_cli import RUN_MODE
from cylc.flow.workflow_status import RunMode


VALIDATE_RUN_MODE = deepcopy(RUN_MODE)
Expand Down Expand Up @@ -123,7 +124,7 @@ def get_option_parser():
{
'check_circular': False,
'profile_mode': False,
'run_mode': 'live'
'run_mode': RunMode.LIVE
}
)

Expand Down
11 changes: 6 additions & 5 deletions cylc/flow/task_events_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
get_template_variables as get_workflow_template_variables,
process_mail_footer,
)
from cylc.flow.workflow_status import RunMode


if TYPE_CHECKING:
Expand Down Expand Up @@ -667,7 +668,7 @@ def process_message(
return True
if (
itask.state.status == TASK_STATUS_PREPARING
or itask.tdef.run_mode == 'simulation'
or itask.tdef.run_mode == RunMode.SIMULATION
):
# If not in the preparing state we already assumed and handled
# job submission under the started event above...
Expand All @@ -677,7 +678,7 @@ def process_message(

# ... but either way update the job ID in the job proxy (it only
# comes in via the submission message).
if itask.tdef.run_mode != 'simulation':
if itask.tdef.run_mode != RunMode.SIMULATION:
job_tokens = itask.tokens.duplicate(
job=str(itask.submit_num)
)
Expand Down Expand Up @@ -824,7 +825,7 @@ def _process_message_check(

def setup_event_handlers(self, itask, event, message):
"""Set up handlers for a task event."""
if itask.tdef.run_mode != 'live':
if itask.tdef.run_mode != RunMode.LIVE:
return
msg = ""
if message != f"job {event}":
Expand Down Expand Up @@ -1242,7 +1243,7 @@ def _process_message_submitted(
)

itask.set_summary_time('submitted', event_time)
if itask.tdef.run_mode == 'simulation':
if itask.tdef.run_mode == RunMode.SIMULATION:
# Simulate job started as well.
itask.set_summary_time('started', event_time)
if itask.state_reset(TASK_STATUS_RUNNING):
Expand Down Expand Up @@ -1277,7 +1278,7 @@ def _process_message_submitted(
'submitted',
event_time,
)
if itask.tdef.run_mode == 'simulation':
if itask.tdef.run_mode == RunMode.SIMULATION:
# Simulate job started as well.
self.data_store_mgr.delta_job_time(
job_tokens,
Expand Down
2 changes: 1 addition & 1 deletion cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1731,7 +1731,7 @@ def force_trigger_tasks(

def sim_time_check(self, message_queue: 'Queue[TaskMsg]') -> bool:
"""Simulation mode: simulate task run times and set states."""
if not self.config.run_mode('simulation'):
if self.config.run_mode() != 'simulation':
return False
sim_task_state_changed = False
now = time()
Expand Down
19 changes: 19 additions & 0 deletions cylc/flow/workflow_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from cylc.flow.wallclock import get_time_string_from_unix_time as time2str

if TYPE_CHECKING:
from optparse import Values
from cylc.flow.scheduler import Scheduler

# Keys for identify API call
Expand Down Expand Up @@ -198,3 +199,21 @@ def get_workflow_status(schd: 'Scheduler') -> Tuple[str, str]:
status_msg = 'running'

return (status.value, status_msg)


class RunMode:
"""The possible run modes of a workflow."""

LIVE = 'live'
"""Workflow will run normally."""

SIMULATION = 'simulation'
"""Workflow will run in simulation mode."""

DUMMY = 'dummy'
"""Workflow will run in dummy mode."""

@staticmethod
def get(options: 'Values') -> str:
"""Return the run mode from the options."""
return getattr(options, 'run_mode', None) or RunMode.LIVE
8 changes: 5 additions & 3 deletions tests/integration/test_mode_on_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import pytest

from cylc.flow.exceptions import InputError
from cylc.flow.scheduler import Scheduler


MODES = [('live'), ('simulation'), ('dummy')]


@pytest.mark.parametrize('mode_before', MODES + [None])
@pytest.mark.parametrize('mode_after', MODES)
@pytest.mark.parametrize('mode_before', MODES + [None])
async def test_restart_mode(
flow, run, scheduler, start, one_conf,
mode_before, mode_after
Expand All @@ -35,12 +36,13 @@ async def test_restart_mode(
N.B. - we need to use run because the check in question only happens
on start.
"""
schd: Scheduler
id_ = flow(one_conf)
schd = scheduler(id_, run_mode=mode_before)
async with start(schd):
if not mode_before:
mode_before = 'live'
assert schd.config.run_mode() == mode_before
assert schd.get_run_mode() == mode_before

schd = scheduler(id_, run_mode=mode_after)

Expand All @@ -50,7 +52,7 @@ async def test_restart_mode(
):
# Restarting in the same mode is fine.
async with run(schd):
assert schd.config.run_mode() == mode_before
assert schd.get_run_mode() == mode_before
else:
# Restarting in a new mode is not:
errormsg = f'^This.*{mode_before} mode: Will.*{mode_after} mode.$'
Expand Down
4 changes: 4 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ ignore=
; Doesn't work at 3.7
B028

per-file-ignores=
; TYPE_CHECKING block suggestions
tests/*: TC001

exclude=
build,
dist,
Expand Down

0 comments on commit 4feb7bb

Please sign in to comment.