Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Dask instead of multiprocessing module #530

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
23 changes: 23 additions & 0 deletions jupyter_scheduler/extension.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio

from dask.distributed import Client as DaskClient
from jupyter_core.paths import jupyter_data_dir
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.transutils import _i18n
Expand Down Expand Up @@ -91,3 +92,25 @@ def initialize_settings(self):
if scheduler.task_runner:
loop = asyncio.get_event_loop()
loop.create_task(scheduler.task_runner.start())

async def stop_extension(self):
"""
Public method called by Jupyter Server when the server is stopping.
This calls the cleanup code defined in `self._stop_exception()` inside
an exception handler, as the server halts if this method raises an
exception.
"""
try:
await self._stop_extension()
except Exception as e:
self.log.error("Jupyter Scheduler raised an exception while stopping:")
andrii-i marked this conversation as resolved.
Show resolved Hide resolved
self.log.exception(e)

async def _stop_extension(self):
"""
Private method that defines the cleanup code to run when the server is
stopping.
"""
if "scheduler" in self.settings:
scheduler: SchedulerApp = self.settings["scheduler"]
await scheduler.stop_extension()
26 changes: 15 additions & 11 deletions jupyter_scheduler/job_files_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, List, Optional, Type

import fsspec
from dask.distributed import Client as DaskClient
from jupyter_server.utils import ensure_async

from jupyter_scheduler.exceptions import SchedulerError
Expand All @@ -23,17 +24,20 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
output_filenames = self.scheduler.get_job_filenames(job)
output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True)

p = Process(
target=Downloader(
output_formats=job.output_formats,
output_filenames=output_filenames,
staging_paths=staging_paths,
output_dir=output_dir,
redownload=redownload,
include_staging_files=job.package_input_folder,
).download
)
p.start()
download = Downloader(
output_formats=job.output_formats,
output_filenames=output_filenames,
staging_paths=staging_paths,
output_dir=output_dir,
redownload=redownload,
include_staging_files=job.package_input_folder,
).download
if self.scheduler.dask_client:
dask_client: DaskClient = self.scheduler.dask_client
dask_client.submit(download)
else:
p = Process(target=download)
p.start()


class Downloader:
Expand Down
47 changes: 34 additions & 13 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import multiprocessing as mp
import asyncio
import os
import random
import shutil
from typing import Dict, List, Optional, Type, Union

import fsspec
import psutil
from dask.distributed import Client as DaskClient
from distributed import LocalCluster
from jupyter_core.paths import jupyter_data_dir
from jupyter_server.transutils import _i18n
from jupyter_server.utils import to_os_path
Expand Down Expand Up @@ -381,6 +383,12 @@ def get_local_output_path(
else:
return os.path.join(self.root_dir, self.output_directory, output_dir_name)

async def stop_extension(self):
"""
Placeholder method for a cleanup code to run when the server is stopping.
"""
pass

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep this method if it does nothing?

Copy link
Collaborator Author

@andrii-i andrii-i Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. BaseScheduler is a base class / klass for schedulers, essentially an interface for all schedulers instantiated and used by SchedulerApp. BaseScheduler.stop_extension creates a placeholder for scheduler cleanup / stop_extension logic that is called from SchedulerApp.stop_extension.
SchedulerApp.stop_extension is a public method inherited from ExtensionApp called by Jupyter Server when the server is stopping.

I implemented stop_extension in default Scheduler in Scheduler.stop_extension to correctly terminate Dask and avoid JupyterLab process not shutting down immediately on interrupt as for example we've seen in Jupyter AI.


class Scheduler(BaseScheduler):
_db_session = None
Expand All @@ -395,6 +403,12 @@ class Scheduler(BaseScheduler):
),
)

dask_cluster_url = Unicode(
allow_none=True,
config=True,
help="URL of the Dask cluster to connect to.",
)

db_url = Unicode(help=_i18n("Scheduler database url"))

task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner")
Expand All @@ -414,6 +428,15 @@ def __init__(
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)

self.dask_client: DaskClient = self._get_dask_client()

def _get_dask_client(self):
"""Creates and configures a Dask client."""
if self.dask_cluster_url:
return DaskClient(self.dask_cluster_url)
cluster = LocalCluster(processes=True)
return DaskClient(cluster)

@property
def db_session(self):
if not self._db_session:
Expand Down Expand Up @@ -478,25 +501,16 @@ def create_job(self, model: CreateJob) -> str:
else:
self.copy_input_file(model.input_uri, staging_paths["input"])

# The MP context forces new processes to not be forked on Linux.
# This is necessary because `asyncio.get_event_loop()` is bugged in
# forked processes in Python versions below 3.12. This method is
# called by `jupyter_core` by `nbconvert` in the default executor.
#
# See: https://github.com/python/cpython/issues/66285
# See also: https://github.com/jupyter/jupyter_core/pull/362
mp_ctx = mp.get_context("spawn")
p = mp_ctx.Process(
target=self.execution_manager_class(
future = self.dask_client.submit(
self.execution_manager_class(
job_id=job.job_id,
staging_paths=staging_paths,
root_dir=self.root_dir,
db_url=self.db_url,
).process
)
p.start()

job.pid = p.pid
job.pid = future.key
session.commit()

job_id = job.job_id
Expand Down Expand Up @@ -777,6 +791,13 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->

return staging_paths

async def stop_extension(self):
"""
Cleanup code to run when the server is stopping.
"""
if self.dask_client:
await self.dask_client.close()


class ArchivingScheduler(Scheduler):
"""Scheduler that captures all files in output directory in an archive."""
Expand Down
41 changes: 21 additions & 20 deletions jupyter_scheduler/tests/test_job_files_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import asyncio
import filecmp
import os
import shutil
import tarfile
import time
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -41,23 +40,25 @@ async def test_copy_from_staging():
}
output_dir = "jobs/1"
with patch("jupyter_scheduler.job_files_manager.Downloader") as mock_downloader:
with patch("jupyter_scheduler.job_files_manager.Process") as mock_process:
with patch("jupyter_scheduler.scheduler.Scheduler") as mock_scheduler:
mock_scheduler.get_job.return_value = job
mock_scheduler.get_staging_paths.return_value = staging_paths
mock_scheduler.get_local_output_path.return_value = output_dir
mock_scheduler.get_job_filenames.return_value = job_filenames
manager = JobFilesManager(scheduler=mock_scheduler)
await manager.copy_from_staging(1)

mock_downloader.assert_called_once_with(
output_formats=job.output_formats,
output_filenames=job_filenames,
staging_paths=staging_paths,
output_dir=output_dir,
redownload=False,
include_staging_files=None,
)
with patch("jupyter_scheduler.scheduler.Scheduler") as mock_scheduler:
mock_future = asyncio.Future()
mock_future.set_result(MagicMock())
mock_scheduler.dask_client_future = mock_future
mock_scheduler.get_job.return_value = job
mock_scheduler.get_staging_paths.return_value = staging_paths
mock_scheduler.get_local_output_path.return_value = output_dir
mock_scheduler.get_job_filenames.return_value = job_filenames
manager = JobFilesManager(scheduler=mock_scheduler)
await manager.copy_from_staging(1)

mock_downloader.assert_called_once_with(
output_formats=job.output_formats,
output_filenames=job_filenames,
staging_paths=staging_paths,
output_dir=output_dir,
redownload=False,
include_staging_files=None,
)


@pytest.fixture
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"pydantic>=1.10,<3",
"sqlalchemy>=2.0,<3",
"croniter~=1.4",
"dask[distributed]",
"pytz==2023.3",
"fsspec==2023.6.0",
"psutil~=5.9"
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading