Skip to content

Commit

Permalink
add dask client, use it for scheduler.create_job
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Jun 21, 2024
1 parent 8eaba42 commit 558e576
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 21 deletions.
31 changes: 27 additions & 4 deletions jupyter_scheduler/extension.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio

from dask.distributed import Client as DaskClient
from jupyter_core.paths import jupyter_data_dir
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.transutils import _i18n
Expand Down Expand Up @@ -73,11 +72,15 @@ def initialize_settings(self):

environments_manager = self.environment_manager_class()

asyncio_loop = self.serverapp.io_loop.asyncio_loop
dask_client_future = asyncio_loop.create_task(self._get_dask_client())

scheduler = self.scheduler_class(
root_dir=self.serverapp.root_dir,
environments_manager=environments_manager,
db_url=self.db_url,
config=self.config,
dask_client_future=dask_client_future,
)

job_files_manager = self.job_files_manager_class(scheduler=scheduler)
Expand All @@ -86,8 +89,28 @@ def initialize_settings(self):
environments_manager=environments_manager,
scheduler=scheduler,
job_files_manager=job_files_manager,
dask_client_future=dask_client_future,
)

if scheduler.task_runner:
loop = asyncio.get_event_loop()
loop.create_task(scheduler.task_runner.start())
asyncio_loop.create_task(scheduler.task_runner.start())

async def _get_dask_client(self):
"""Creates and configures a Dask client."""
return DaskClient(processes=False, asynchronous=True)

async def stop_extension(self):
"""Called by the Jupyter Server when stopping to cleanup resources."""
try:
await self._stop_extension()
except Exception as e:
self.log.error("Error while stopping Jupyter Scheduler:")
self.log.exception(e)

async def _stop_extension(self):
"""Closes the Dask client if it exists."""
if "dask_client_future" in self.settings:
dask_client: DaskClient = await self.settings["dask_client_future"]
self.log.info("Closing Dask client.")
await dask_client.close()
self.log.info("Dask client closed.")
35 changes: 18 additions & 17 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import os
import random
import shutil
from typing import Dict, List, Optional, Type, Union
from typing import Awaitable, Dict, List, Optional, Type, Union

import fsspec
import psutil
from dask.distributed import Client as DaskClient
from jupyter_core.paths import jupyter_data_dir
from jupyter_server.transutils import _i18n
from jupyter_server.utils import to_os_path
Expand Down Expand Up @@ -96,11 +97,17 @@ def _default_staging_path(self):
)

def __init__(
self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs
self,
root_dir: str,
environments_manager: Type[EnvironmentManager],
dask_client_future: Awaitable[DaskClient],
config=None,
**kwargs,
):
super().__init__(config=config, **kwargs)
self.root_dir = root_dir
self.environments_manager = environments_manager
self.dask_client_future = dask_client_future

def create_job(self, model: CreateJob) -> str:
"""Creates a new job record, may trigger execution of the job.
Expand Down Expand Up @@ -437,7 +444,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]:
destination_dir=staging_dir,
)

def create_job(self, model: CreateJob) -> str:
async def create_job(self, model: CreateJob) -> str:
if not model.job_definition_id and not self.file_exists(model.input_uri):
raise InputUriError(model.input_uri)

Expand Down Expand Up @@ -478,25 +485,17 @@ def create_job(self, model: CreateJob) -> str:
else:
self.copy_input_file(model.input_uri, staging_paths["input"])

# The MP context forces new processes to not be forked on Linux.
# This is necessary because `asyncio.get_event_loop()` is bugged in
# forked processes in Python versions below 3.12. This method is
# called by `jupyter_core` by `nbconvert` in the default executor.
#
# See: https://github.com/python/cpython/issues/66285
# See also: https://github.com/jupyter/jupyter_core/pull/362
mp_ctx = mp.get_context("spawn")
p = mp_ctx.Process(
target=self.execution_manager_class(
dask_client: DaskClient = await self.dask_client_future
future = dask_client.submit(
self.execution_manager_class(
job_id=job.job_id,
staging_paths=staging_paths,
root_dir=self.root_dir,
db_url=self.db_url,
).process
)
p.start()

job.pid = p.pid
job.pid = future.key
session.commit()

job_id = job.job_id
Expand Down Expand Up @@ -749,14 +748,16 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit

return list_response

def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition):
async def create_job_from_definition(
self, job_definition_id: str, model: CreateJobFromDefinition
):
job_id = None
definition = self.get_job_definition(job_definition_id)
if definition:
input_uri = self.get_staging_paths(definition)["input"]
attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True)
attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri}
job_id = self.create_job(CreateJob(**attributes))
job_id = await self.create_job(CreateJob(**attributes))

return job_id

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"pydantic>=1.10,<3",
"sqlalchemy>=2.0,<3",
"croniter~=1.4",
"dask[distributed]",
"pytz==2023.3",
"fsspec==2023.6.0",
"psutil~=5.9"
Expand Down

0 comments on commit 558e576

Please sign in to comment.