diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 62483a1777..55c9ecc6a6 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -88,7 +88,7 @@ # job manager lock JM_LOCK = threading.RLock() -tmonitor: t.Optional[t.Any] = None +TELEMETRY_MONITOR: t.Optional[t.Any] = None def start_telemetry_monitor( @@ -114,43 +114,12 @@ def start_telemetry_monitor( return process -def stop_telemetry_monitor(process: t.Any) -> None: - # if self._telemetry is None: - # return +def start_telemetry_callback(_job: Job, _logger: Logger) -> None: + global TELEMETRY_MONITOR # pylint: disable=global-statement - logger.debug("Stopping telemetry monitor process") - - try: - process.terminate() - except Exception: - logger.warning( - "An error occurred while terminating the telemetry monitor", exc_info=True - ) - - -def start_telemetry_callback_wrapper( - _manager: JobManager, -) -> t.Callable[[Job, Logger], None]: - def start_telemetry_callback(_job: Job, _logger: Logger) -> None: - global tmonitor # pylint: disable=global-statement - # if not os.environ.get("SMARTSIM_TELEMETRY_ENABLED", False): - # return - - if tmonitor is None: - tmonitor = start_telemetry_monitor() - - return start_telemetry_callback - - -def stop_telemetry_callback_wrapper( - _manager: JobManager, -) -> t.Callable[[Job, Logger], None]: - def stop_telemetry_callback(_job: Job, _logger: Logger) -> None: - # global tmonitor # pylint: disable=global-statement - if tmonitor is not None: - stop_telemetry_monitor(tmonitor) - - return stop_telemetry_callback + # if never started or if a prior instance has shut down (poll() refreshes returncode) + if TELEMETRY_MONITOR is None or TELEMETRY_MONITOR.poll() is not None: + TELEMETRY_MONITOR = start_telemetry_monitor() class Controller: @@ -166,10 +135,8 @@ def __init__(self, launcher: str = "local") -> None: :type launcher: str """ self._jobs = JobManager(JM_LOCK) - self._jobs.add_job_onstart_callback( 
start_telemetry_callback_wrapper(self._jobs) - ) - self._jobs.add_job_onstop_callback(stop_telemetry_callback_wrapper(self._jobs)) + # have job manager launch a telemetry monitor if any jobs are run + self._jobs.add_job_onstart_callback(start_telemetry_callback) self.init_launcher(launcher) @@ -769,41 +736,6 @@ def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None: # launch explicitly raise - # def start_telemetry_monitor(self, - # exp_dir: str = ".", - # frequency: int = 10) -> None: - # if self._telemetry is None: - # logger.debug("Starting telemetry monitor process") - # self._telemetry = subprocess.Popen( - # [ - # sys.executable, - # "-m", - # "smartsim._core.entrypoints.telemetrymonitor", - # "-d", - # exp_dir, - # "-f", - # str(frequency) - # ], - # stdin=subprocess.PIPE, - # stdout=subprocess.PIPE, - # cwd=str(pathlib.Path(__file__).parent.parent.parent), - # shell=False, - # ) - - # def stop_telemetry_monitor(self) -> None: - # if self._telemetry is None: - # return - - # logger.debug("Stopping telemetry monitor process") - - # try: - # self._telemetry.terminate() - # except Exception: - # logger.warn("An error occurred while terminating the telemetry monitor", - # exc_info=True) - # finally: - # self._telemetry = None - def reload_saved_db(self, checkpoint_file: str) -> Orchestrator: with JM_LOCK: if self.orchestrator_active: diff --git a/smartsim/_core/control/jobmanager.py b/smartsim/_core/control/jobmanager.py index 7b2e41f94e..0371683c71 100644 --- a/smartsim/_core/control/jobmanager.py +++ b/smartsim/_core/control/jobmanager.py @@ -235,9 +235,6 @@ def add_job( else: self.jobs[entity.name] = job - # if self._telemetry is None or not self._telemetry.is_alive(): - # self._telemetry = start_monitor() - for hook in self.on_start_hook: hook(job, logger) diff --git a/smartsim/_core/entrypoints/telemetrymonitor.py b/smartsim/_core/entrypoints/telemetrymonitor.py index bd4b994aee..3cf6f35f1e 100644 --- 
a/smartsim/_core/entrypoints/telemetrymonitor.py +++ b/smartsim/_core/entrypoints/telemetrymonitor.py @@ -337,7 +337,7 @@ def __init__( self._completed_jobs: t.Dict[_JobKey, JobEntity] = {} self._launcher_type: str = "" self._launcher: t.Optional[Launcher] = None - self._jm: JobManager = JobManager(threading.RLock()) + self.job_manager: JobManager = JobManager(threading.RLock()) self._launcher_map: t.Dict[str, t.Type[Launcher]] = { "slurm": SlurmLauncher, "pbs": PBSLauncher, @@ -371,12 +371,12 @@ def set_launcher(self, launcher_type: str) -> None: if launcher_type != self._launcher_type: self._launcher_type = launcher_type self._launcher = self.init_launcher(launcher_type) - self._jm.set_launcher(self._launcher) - self._jm.add_job_onstart_callback(track_started) - self._jm.add_job_onstop_callback(track_completed) - self._jm.add_job_onstep_callback(track_timestep) + self.job_manager.set_launcher(self._launcher) + self.job_manager.add_job_onstart_callback(track_started) + self.job_manager.add_job_onstop_callback(track_completed) + self.job_manager.add_job_onstep_callback(track_timestep) - self._jm.start() + self.job_manager.start() @property def launcher(self) -> Launcher: @@ -405,7 +405,7 @@ def process_manifest(self, manifest_path: str) -> None: self.set_launcher(manifest.launcher) - if not self._jm._launcher: # pylint: disable=protected-access + if not self.job_manager._launcher: # pylint: disable=protected-access raise SmartSimError(f"Unable to set launcher from {manifest_path}") runs = [run for run in manifest.runs if run.timestamp not in self._tracked_runs] @@ -431,14 +431,14 @@ def process_manifest(self, manifest_path: str) -> None: self._logger, ) - self._jm.add_job( + self.job_manager.add_job( entity.name, entity.job_id, entity, entity.is_managed, # is_orch=entity.is_db, ) - self._jm._launcher.step_mapping.add( # pylint: disable=protected-access + self.job_manager._launcher.step_mapping.add( # pylint: disable=protected-access entity.name, entity.step_id, 
entity.step_id, entity.is_managed ) self._tracked_runs[run.timestamp] = run @@ -478,8 +478,8 @@ def _to_completed( if entity.key not in self._completed_jobs: self._completed_jobs[entity.key] = inactive_entity - job = self._jm[entity.name] - self._jm.move_to_completed(job) + job = self.job_manager[entity.name] + self.job_manager.move_to_completed(job) if step_info: detail = f"status: {step_info.status}, error: {step_info.error}" @@ -515,6 +515,18 @@ def on_timestep(self, timestamp: int, exp_dir: pathlib.Path) -> None: self._to_completed(timestamp, completed_entity, exp_dir, step_info) +def shutdown_when_completed( + observer: BaseObserver, action_handler: ManifestEventHandler +) -> None: + """Inspect active and completed job queues and shut down if all jobs are complete""" + has_running_jobs = bool( + action_handler.job_manager.jobs or action_handler.job_manager.db_jobs + ) + has_completed_jobs = action_handler.job_manager.completed + if not has_running_jobs and has_completed_jobs: + observer.stop() # type: ignore[no-untyped-call] + + def event_loop( observer: BaseObserver, action_handler: ManifestEventHandler, @@ -523,6 +535,7 @@ def event_loop( num_iters: int, logger: logging.Logger, ) -> None: + """Executes all attached timestep handlers once per polling interval""" num_iters = num_iters if num_iters > 0 else 0 # ensure non-negative limits remaining = num_iters if num_iters else 0 # track completed iterations @@ -536,6 +549,8 @@ if num_iters and not remaining: break + shutdown_when_completed(observer, action_handler) + def main( frequency: t.Union[int, float],