From 558e57607d34323d2c1bbebc9517de81e9a28d62 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Fri, 21 Jun 2024 10:50:43 -0700 Subject: [PATCH] add dask client, use it for scheduler.create_job --- jupyter_scheduler/extension.py | 31 ++++++++++++++++++++++++++---- jupyter_scheduler/scheduler.py | 35 +++++++++++++++++----------------- pyproject.toml | 1 + 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py index 1a4ba3736..3fc323361 100644 --- a/jupyter_scheduler/extension.py +++ b/jupyter_scheduler/extension.py @@ -1,5 +1,4 @@ -import asyncio - +from dask.distributed import Client as DaskClient from jupyter_core.paths import jupyter_data_dir from jupyter_server.extension.application import ExtensionApp from jupyter_server.transutils import _i18n @@ -73,11 +72,15 @@ def initialize_settings(self): environments_manager = self.environment_manager_class() + asyncio_loop = self.serverapp.io_loop.asyncio_loop + dask_client_future = asyncio_loop.create_task(self._get_dask_client()) + scheduler = self.scheduler_class( root_dir=self.serverapp.root_dir, environments_manager=environments_manager, db_url=self.db_url, config=self.config, + dask_client_future=dask_client_future, ) job_files_manager = self.job_files_manager_class(scheduler=scheduler) @@ -86,8 +89,28 @@ def initialize_settings(self): environments_manager=environments_manager, scheduler=scheduler, job_files_manager=job_files_manager, + dask_client_future=dask_client_future, ) if scheduler.task_runner: - loop = asyncio.get_event_loop() - loop.create_task(scheduler.task_runner.start()) + asyncio_loop.create_task(scheduler.task_runner.start()) + + async def _get_dask_client(self): + """Creates and configures a Dask client.""" + return DaskClient(processes=False, asynchronous=True) + + async def stop_extension(self): + """Called by the Jupyter Server when stopping to cleanup resources.""" + try: + await self._stop_extension() + except Exception as e: + self.log.error("Error while stopping Jupyter Scheduler:") + self.log.exception(e) + + async def _stop_extension(self): + """Closes the Dask client if it exists.""" + if "dask_client_future" in self.settings: + dask_client: DaskClient = await self.settings["dask_client_future"] + self.log.info("Closing Dask client.") + await dask_client.close() + self.log.info("Dask client closed.") diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 867034c60..edbc3740e 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -2,10 +2,11 @@ import os import random import shutil -from typing import Dict, List, Optional, Type, Union +from typing import Awaitable, Dict, List, Optional, Type, Union import fsspec import psutil +from dask.distributed import Client as DaskClient from jupyter_core.paths import jupyter_data_dir from jupyter_server.transutils import _i18n from jupyter_server.utils import to_os_path @@ -96,11 +97,17 @@ def _default_staging_path(self): ) def __init__( - self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs + self, + root_dir: str, + environments_manager: Type[EnvironmentManager], + dask_client_future: Awaitable[DaskClient], + config=None, + **kwargs, ): super().__init__(config=config, **kwargs) self.root_dir = root_dir self.environments_manager = environments_manager + self.dask_client_future = dask_client_future def create_job(self, model: CreateJob) -> str: """Creates a new job record, may trigger execution of the job. @@ -437,7 +444,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]: destination_dir=staging_dir, ) - def create_job(self, model: CreateJob) -> str: + async def create_job(self, model: CreateJob) -> str: if not model.job_definition_id and not self.file_exists(model.input_uri): raise InputUriError(model.input_uri) @@ -478,25 +485,17 @@ def create_job(self, model: CreateJob) -> str: else: self.copy_input_file(model.input_uri, staging_paths["input"]) - # The MP context forces new processes to not be forked on Linux. - # This is necessary because `asyncio.get_event_loop()` is bugged in - # forked processes in Python versions below 3.12. This method is - # called by `jupyter_core` by `nbconvert` in the default executor. - # - # See: https://github.com/python/cpython/issues/66285 - # See also: https://github.com/jupyter/jupyter_core/pull/362 - mp_ctx = mp.get_context("spawn") - p = mp_ctx.Process( - target=self.execution_manager_class( + dask_client: DaskClient = await self.dask_client_future + future = dask_client.submit( + self.execution_manager_class( job_id=job.job_id, staging_paths=staging_paths, root_dir=self.root_dir, db_url=self.db_url, ).process ) - p.start() - job.pid = p.pid + job.pid = future.key session.commit() job_id = job.job_id @@ -749,14 +748,16 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit return list_response - def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition): + async def create_job_from_definition( + self, job_definition_id: str, model: CreateJobFromDefinition + ): job_id = None definition = self.get_job_definition(job_definition_id) if definition: input_uri = self.get_staging_paths(definition)["input"] attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True) attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri} - job_id = self.create_job(CreateJob(**attributes)) + job_id = await self.create_job(CreateJob(**attributes)) return job_id diff --git a/pyproject.toml b/pyproject.toml index 2ae7b9476..f5def93cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "pydantic>=1.10,<3", "sqlalchemy>=2.0,<3", "croniter~=1.4", + "dask[distributed]", "pytz==2023.3", "fsspec==2023.6.0", "psutil~=5.9"