Skip to content

Commit

Permalink
add stop_extension logic, use it for stopping dask
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Jun 25, 2024
1 parent d7c1fec commit d753850
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
22 changes: 22 additions & 0 deletions jupyter_scheduler/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,25 @@ def initialize_settings(self):
if scheduler.task_runner:
loop = asyncio.get_event_loop()
loop.create_task(scheduler.task_runner.start())

async def stop_extension(self):
"""
Public method called by Jupyter Server when the server is stopping.
This calls the cleanup code defined in `self._stop_exception()` inside
an exception handler, as the server halts if this method raises an
exception.
"""
try:
await self._stop_extension()
except Exception as e:
self.log.error("Jupyter Scheduler raised an exception while stopping:")
self.log.exception(e)

async def _stop_extension(self):
"""
Private method that defines the cleanup code to run when the server is
stopping.
"""
if "scheduler" in self.settings:
scheduler: SchedulerApp = self.settings["scheduler"]
await scheduler.stop_extension()
34 changes: 22 additions & 12 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,12 @@ def _default_staging_path(self):
)

def __init__(
self,
root_dir: str,
environments_manager: Type[EnvironmentManager],
config=None,
**kwargs,
self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs
):
super().__init__(config=config, **kwargs)
self.root_dir = root_dir
self.environments_manager = environments_manager

loop = asyncio.get_event_loop()
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())

async def _get_dask_client(self):
"""Creates and configures a Dask client."""
return DaskClient(processes=False, asynchronous=True)

def create_job(self, model: CreateJob) -> str:
"""Creates a new job record, may trigger execution of the job.
In case a task runner is actually handling execution of the jobs,
Expand Down Expand Up @@ -393,6 +382,12 @@ def get_local_output_path(
else:
return os.path.join(self.root_dir, self.output_directory, output_dir_name)

async def stop_extension(self):
"""
Placeholder method for a cleanup code to run when the server is stopping.
"""
pass


class Scheduler(BaseScheduler):
_db_session = None
Expand Down Expand Up @@ -426,6 +421,13 @@ def __init__(
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)

loop = asyncio.get_event_loop()
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())

async def _get_dask_client(self):
"""Creates and configures a Dask client."""
return DaskClient(processes=False, asynchronous=True)

@property
def db_session(self):
if not self._db_session:
Expand Down Expand Up @@ -783,6 +785,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->

return staging_paths

async def stop_extension(self):
"""
Cleanup code to run when the server is stopping.
"""
if self.dask_client_future:
dask_client: DaskClient = await self.dask_client_future
await dask_client.close()


class ArchivingScheduler(Scheduler):
"""Scheduler that captures all files in output directory in an archive."""
Expand Down

0 comments on commit d753850

Please sign in to comment.