Skip to content

Commit

Permalink
auto start/stop telemetry manager
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Oct 19, 2023
1 parent 4525572 commit edda5d7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 90 deletions.
84 changes: 8 additions & 76 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@

# job manager lock
JM_LOCK = threading.RLock()
tmonitor: t.Optional[t.Any] = None
TELEMETRY_MONITOR: t.Optional[t.Any] = None


def start_telemetry_monitor(
Expand All @@ -114,43 +114,12 @@ def start_telemetry_monitor(
return process


def stop_telemetry_monitor(process: t.Any) -> None:
# if self._telemetry is None:
# return
def start_telemetry_callback(_job: Job, _logger: Logger) -> None:
global TELEMETRY_MONITOR # pylint: disable=global-statement

logger.debug("Stopping telemetry monitor process")

try:
process.terminate()
except Exception:
logger.warning(
"An error occurred while terminating the telemetry monitor", exc_info=True
)


def start_telemetry_callback_wrapper(
_manager: JobManager,
) -> t.Callable[[Job, Logger], None]:
def start_telemetry_callback(_job: Job, _logger: Logger) -> None:
global tmonitor # pylint: disable=global-statement
# if not os.environ.get("SMARTSIM_TELEMETRY_ENABLED", False):
# return

if tmonitor is None:
tmonitor = start_telemetry_monitor()

return start_telemetry_callback


def stop_telemetry_callback_wrapper(
_manager: JobManager,
) -> t.Callable[[Job, Logger], None]:
def stop_telemetry_callback(_job: Job, _logger: Logger) -> None:
# global tmonitor # pylint: disable=global-statement
if tmonitor is not None:
stop_telemetry_monitor(tmonitor)

return stop_telemetry_callback
# if never started or if a prior instance has shut down
if TELEMETRY_MONITOR is None or TELEMETRY_MONITOR.returncode is not None:
TELEMETRY_MONITOR = start_telemetry_monitor()


class Controller:
Expand All @@ -166,10 +135,8 @@ def __init__(self, launcher: str = "local") -> None:
:type launcher: str
"""
self._jobs = JobManager(JM_LOCK)
self._jobs.add_job_onstart_callback(
start_telemetry_callback_wrapper(self._jobs)
)
self._jobs.add_job_onstop_callback(stop_telemetry_callback_wrapper(self._jobs))
# have job manager launch a telemetry monitor if any jobs are run
self._jobs.add_job_onstart_callback(start_telemetry_callback)

self.init_launcher(launcher)

Expand Down Expand Up @@ -769,41 +736,6 @@ def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None:
# launch explicitly
raise

# def start_telemetry_monitor(self,
# exp_dir: str = ".",
# frequency: int = 10) -> None:
# if self._telemetry is None:
# logger.debug("Starting telemetry monitor process")
# self._telemetry = subprocess.Popen(
# [
# sys.executable,
# "-m",
# "smartsim._core.entrypoints.telemetrymonitor",
# "-d",
# exp_dir,
# "-f",
# str(frequency)
# ],
# stdin=subprocess.PIPE,
# stdout=subprocess.PIPE,
# cwd=str(pathlib.Path(__file__).parent.parent.parent),
# shell=False,
# )

# def stop_telemetry_monitor(self) -> None:
# if self._telemetry is None:
# return

# logger.debug("Stopping telemetry monitor process")

# try:
# self._telemetry.terminate()
# except Exception:
# logger.warn("An error occurred while terminating the telemetry monitor",
# exc_info=True)
# finally:
# self._telemetry = None

def reload_saved_db(self, checkpoint_file: str) -> Orchestrator:
with JM_LOCK:
if self.orchestrator_active:
Expand Down
3 changes: 0 additions & 3 deletions smartsim/_core/control/jobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,6 @@ def add_job(
else:
self.jobs[entity.name] = job

# if self._telemetry is None or not self._telemetry.is_alive():
# self._telemetry = start_monitor()

for hook in self.on_start_hook:
hook(job, logger)

Expand Down
37 changes: 26 additions & 11 deletions smartsim/_core/entrypoints/telemetrymonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def __init__(
self._completed_jobs: t.Dict[_JobKey, JobEntity] = {}
self._launcher_type: str = ""
self._launcher: t.Optional[Launcher] = None
self._jm: JobManager = JobManager(threading.RLock())
self.job_manager: JobManager = JobManager(threading.RLock())
self._launcher_map: t.Dict[str, t.Type[Launcher]] = {
"slurm": SlurmLauncher,
"pbs": PBSLauncher,
Expand Down Expand Up @@ -371,12 +371,12 @@ def set_launcher(self, launcher_type: str) -> None:
if launcher_type != self._launcher_type:
self._launcher_type = launcher_type
self._launcher = self.init_launcher(launcher_type)
self._jm.set_launcher(self._launcher)
self._jm.add_job_onstart_callback(track_started)
self._jm.add_job_onstop_callback(track_completed)
self._jm.add_job_onstep_callback(track_timestep)
self.job_manager.set_launcher(self._launcher)
self.job_manager.add_job_onstart_callback(track_started)
self.job_manager.add_job_onstop_callback(track_completed)
self.job_manager.add_job_onstep_callback(track_timestep)

self._jm.start()
self.job_manager.start()

@property
def launcher(self) -> Launcher:
Expand Down Expand Up @@ -405,7 +405,7 @@ def process_manifest(self, manifest_path: str) -> None:

self.set_launcher(manifest.launcher)

if not self._jm._launcher: # pylint: disable=protected-access
if not self.job_manager._launcher: # pylint: disable=protected-access
raise SmartSimError(f"Unable to set launcher from {manifest_path}")

runs = [run for run in manifest.runs if run.timestamp not in self._tracked_runs]
Expand All @@ -431,14 +431,14 @@ def process_manifest(self, manifest_path: str) -> None:
self._logger,
)

self._jm.add_job(
self.job_manager.add_job(
entity.name,
entity.job_id,
entity,
entity.is_managed,
# is_orch=entity.is_db,
)
self._jm._launcher.step_mapping.add( # pylint: disable=protected-access
self.job_manager._launcher.step_mapping.add( # pylint: disable=protected-access
entity.name, entity.step_id, entity.step_id, entity.is_managed
)
self._tracked_runs[run.timestamp] = run
Expand Down Expand Up @@ -478,8 +478,8 @@ def _to_completed(
if entity.key not in self._completed_jobs:
self._completed_jobs[entity.key] = inactive_entity

job = self._jm[entity.name]
self._jm.move_to_completed(job)
job = self.job_manager[entity.name]
self.job_manager.move_to_completed(job)

if step_info:
detail = f"status: {step_info.status}, error: {step_info.error}"
Expand Down Expand Up @@ -515,6 +515,18 @@ def on_timestep(self, timestamp: int, exp_dir: pathlib.Path) -> None:
self._to_completed(timestamp, completed_entity, exp_dir, step_info)


def shutdown_when_completed(
    observer: "BaseObserver", action_handler: "ManifestEventHandler"
) -> None:
    """Stop the observer once every tracked job has finished.

    Inspects the handler's job manager and stops the monitoring observer
    only when both job queues are empty AND at least one job has completed
    (so the monitor is not torn down before any work was ever tracked).

    :param observer: the observer driving the monitoring loop
    :param action_handler: event handler whose job manager tracks
        active and completed jobs
    """
    # True only when both the regular and the DB job queues are empty.
    # NOTE: the original code named this `has_running_jobs` while computing
    # its negation, then tested `not has_running_jobs` — shutting down while
    # jobs were still running. The predicate is renamed and the guard fixed.
    has_no_running_jobs = (
        not action_handler.job_manager.jobs and not action_handler.job_manager.db_jobs
    )
    # require at least one completed job before considering shutdown
    has_completed_jobs = bool(action_handler.job_manager.completed)
    if has_no_running_jobs and has_completed_jobs:
        observer.stop()  # type: ignore[no-untyped-call]


def event_loop(
observer: BaseObserver,
action_handler: ManifestEventHandler,
Expand All @@ -523,6 +535,7 @@ def event_loop(
num_iters: int,
logger: logging.Logger,
) -> None:
"""Executes all attached timestep handlers every <frequency> seconds"""
num_iters = num_iters if num_iters > 0 else 0 # ensure non-negative limits
remaining = num_iters if num_iters else 0 # track completed iterations

Expand All @@ -536,6 +549,8 @@ def event_loop(
if num_iters and not remaining:
break

shutdown_when_completed(observer, action_handler)


def main(
frequency: t.Union[int, float],
Expand Down

0 comments on commit edda5d7

Please sign in to comment.