From d75385044500904ea964b17bb4b9ad02df76413d Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Mon, 24 Jun 2024 17:50:16 -0700 Subject: [PATCH] add stop_extension logic, use it for stopping dask --- jupyter_scheduler/extension.py | 22 ++++++++++++++++++++++ jupyter_scheduler/scheduler.py | 34 ++++++++++++++++++++++------------ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py index 001a9adf..b1819279 100644 --- a/jupyter_scheduler/extension.py +++ b/jupyter_scheduler/extension.py @@ -92,3 +92,25 @@ def initialize_settings(self): if scheduler.task_runner: loop = asyncio.get_event_loop() loop.create_task(scheduler.task_runner.start()) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_exception()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter Scheduler raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. + """ + if "scheduler" in self.settings: + scheduler: SchedulerApp = self.settings["scheduler"] + await scheduler.stop_extension() diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index f204df73..2ae53a13 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -97,23 +97,12 @@ def _default_staging_path(self): ) def __init__( - self, - root_dir: str, - environments_manager: Type[EnvironmentManager], - config=None, - **kwargs, + self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs ): super().__init__(config=config, **kwargs) self.root_dir = root_dir self.environments_manager = environments_manager - loop = asyncio.get_event_loop() - self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client()) - - async def _get_dask_client(self): - """Creates and configures a Dask client.""" - return DaskClient(processes=False, asynchronous=True) - def create_job(self, model: CreateJob) -> str: """Creates a new job record, may trigger execution of the job. In case a task runner is actually handling execution of the jobs, @@ -393,6 +382,12 @@ def get_local_output_path( else: return os.path.join(self.root_dir, self.output_directory, output_dir_name) + async def stop_extension(self): + """ + Placeholder method for a cleanup code to run when the server is stopping. + """ + pass + class Scheduler(BaseScheduler): _db_session = None @@ -426,6 +421,13 @@ def __init__( if self.task_runner_class: self.task_runner = self.task_runner_class(scheduler=self, config=config) + loop = asyncio.get_event_loop() + self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client()) + + async def _get_dask_client(self): + """Creates and configures a Dask client.""" + return DaskClient(processes=False, asynchronous=True) + @property def db_session(self): if not self._db_session: @@ -783,6 +785,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) -> return staging_paths + async def stop_extension(self): + """ + Cleanup code to run when the server is stopping. + """ + if self.dask_client_future: + dask_client: DaskClient = await self.dask_client_future + await dask_client.close() + class ArchivingScheduler(Scheduler): """Scheduler that captures all files in output directory in an archive."""