diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py index 001a9adf..b1819279 100644 --- a/jupyter_scheduler/extension.py +++ b/jupyter_scheduler/extension.py @@ -92,3 +92,25 @@ def initialize_settings(self): if scheduler.task_runner: loop = asyncio.get_event_loop() loop.create_task(scheduler.task_runner.start()) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_exception()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter Scheduler raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. + """ + if "scheduler" in self.settings: + scheduler: SchedulerApp = self.settings["scheduler"] + await scheduler.stop_extension() diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index e574df80..a40a6ea8 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -98,23 +98,12 @@ def _default_staging_path(self): ) def __init__( - self, - root_dir: str, - environments_manager: Type[EnvironmentManager], - config=None, - **kwargs, + self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs ): super().__init__(config=config, **kwargs) self.root_dir = root_dir self.environments_manager = environments_manager - loop = asyncio.get_event_loop() - self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client()) - - async def _get_dask_client(self): - """Creates and configures a Dask client.""" - return DaskClient(processes=False, asynchronous=True) - def create_job(self, model: CreateJob) -> str: """Creates a new job record, may trigger execution of the job. In case a task runner is actually handling execution of the jobs, @@ -394,6 +383,12 @@ def get_local_output_path( else: return os.path.join(self.root_dir, self.output_directory, output_dir_name) + async def stop_extension(self): + """ + Placeholder method for a cleanup code to run when the server is stopping. + """ + pass + class Scheduler(BaseScheduler): _db_session = None @@ -427,6 +422,13 @@ def __init__( if self.task_runner_class: self.task_runner = self.task_runner_class(scheduler=self, config=config) + loop = asyncio.get_event_loop() + self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client()) + + async def _get_dask_client(self): + """Creates and configures a Dask client.""" + return DaskClient(processes=False, asynchronous=True) + @property def db_session(self): if not self._db_session: @@ -775,6 +777,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) -> return staging_paths + async def stop_extension(self): + """ + Cleanup code to run when the server is stopping. + """ + if self.dask_client_future: + dask_client: DaskClient = await self.dask_client_future + await dask_client.close() + class ArchivingScheduler(Scheduler): """Scheduler that captures all files in output directory in an archive."""