diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py index 1a4ba373..b1819279 100644 --- a/jupyter_scheduler/extension.py +++ b/jupyter_scheduler/extension.py @@ -1,5 +1,6 @@ import asyncio +from dask.distributed import Client as DaskClient from jupyter_core.paths import jupyter_data_dir from jupyter_server.extension.application import ExtensionApp from jupyter_server.transutils import _i18n @@ -91,3 +92,25 @@ def initialize_settings(self): if scheduler.task_runner: loop = asyncio.get_event_loop() loop.create_task(scheduler.task_runner.start()) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_extension()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter Scheduler raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. 
+ """ + if "scheduler" in self.settings: + scheduler: SchedulerApp = self.settings["scheduler"] + await scheduler.stop_extension() diff --git a/jupyter_scheduler/job_files_manager.py b/jupyter_scheduler/job_files_manager.py index 0e39c2b7..384bcbbd 100644 --- a/jupyter_scheduler/job_files_manager.py +++ b/jupyter_scheduler/job_files_manager.py @@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Type import fsspec +from dask.distributed import Client as DaskClient from jupyter_server.utils import ensure_async from jupyter_scheduler.exceptions import SchedulerError @@ -23,17 +24,20 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals output_filenames = self.scheduler.get_job_filenames(job) output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True) - p = Process( - target=Downloader( - output_formats=job.output_formats, - output_filenames=output_filenames, - staging_paths=staging_paths, - output_dir=output_dir, - redownload=redownload, - include_staging_files=job.package_input_folder, - ).download - ) - p.start() + download = Downloader( + output_formats=job.output_formats, + output_filenames=output_filenames, + staging_paths=staging_paths, + output_dir=output_dir, + redownload=redownload, + include_staging_files=job.package_input_folder, + ).download + if self.scheduler.dask_client: + dask_client: DaskClient = self.scheduler.dask_client + dask_client.submit(download) + else: + p = Process(target=download) + p.start() class Downloader: diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 867034c6..3ec33aab 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -1,4 +1,4 @@ -import multiprocessing as mp +import asyncio import os import random import shutil @@ -6,6 +6,8 @@ import fsspec import psutil +from dask.distributed import Client as DaskClient +from distributed import LocalCluster from jupyter_core.paths import jupyter_data_dir from 
jupyter_server.transutils import _i18n from jupyter_server.utils import to_os_path @@ -381,6 +383,12 @@ def get_local_output_path( else: return os.path.join(self.root_dir, self.output_directory, output_dir_name) + async def stop_extension(self): + """ + Placeholder method for a cleanup code to run when the server is stopping. + """ + pass + class Scheduler(BaseScheduler): _db_session = None @@ -395,6 +403,12 @@ class Scheduler(BaseScheduler): ), ) + dask_cluster_url = Unicode( + allow_none=True, + config=True, + help="URL of the Dask cluster to connect to.", + ) + db_url = Unicode(help=_i18n("Scheduler database url")) task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner") @@ -414,6 +428,15 @@ def __init__( if self.task_runner_class: self.task_runner = self.task_runner_class(scheduler=self, config=config) + self.dask_client: DaskClient = self._get_dask_client() + + def _get_dask_client(self): + """Creates and configures a Dask client.""" + if self.dask_cluster_url: + return DaskClient(self.dask_cluster_url) + cluster = LocalCluster(processes=True) + return DaskClient(cluster) + @property def db_session(self): if not self._db_session: @@ -478,25 +501,16 @@ def create_job(self, model: CreateJob) -> str: else: self.copy_input_file(model.input_uri, staging_paths["input"]) - # The MP context forces new processes to not be forked on Linux. - # This is necessary because `asyncio.get_event_loop()` is bugged in - # forked processes in Python versions below 3.12. This method is - # called by `jupyter_core` by `nbconvert` in the default executor. 
- # - # See: https://github.com/python/cpython/issues/66285 - # See also: https://github.com/jupyter/jupyter_core/pull/362 - mp_ctx = mp.get_context("spawn") - p = mp_ctx.Process( - target=self.execution_manager_class( + future = self.dask_client.submit( + self.execution_manager_class( job_id=job.job_id, staging_paths=staging_paths, root_dir=self.root_dir, db_url=self.db_url, ).process ) - p.start() - job.pid = p.pid + job.pid = future.key session.commit() job_id = job.job_id @@ -777,6 +791,13 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) -> return staging_paths + async def stop_extension(self): + """ + Cleanup code to run when the server is stopping. + """ + if self.dask_client: + await self.dask_client.close() + class ArchivingScheduler(Scheduler): """Scheduler that captures all files in output directory in an archive.""" diff --git a/jupyter_scheduler/tests/test_job_files_manager.py b/jupyter_scheduler/tests/test_job_files_manager.py index 66a9727b..52c0564c 100644 --- a/jupyter_scheduler/tests/test_job_files_manager.py +++ b/jupyter_scheduler/tests/test_job_files_manager.py @@ -1,10 +1,9 @@ +import asyncio import filecmp import os import shutil import tarfile -import time -from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -41,23 +40,25 @@ async def test_copy_from_staging(): } output_dir = "jobs/1" with patch("jupyter_scheduler.job_files_manager.Downloader") as mock_downloader: - with patch("jupyter_scheduler.job_files_manager.Process") as mock_process: - with patch("jupyter_scheduler.scheduler.Scheduler") as mock_scheduler: - mock_scheduler.get_job.return_value = job - mock_scheduler.get_staging_paths.return_value = staging_paths - mock_scheduler.get_local_output_path.return_value = output_dir - mock_scheduler.get_job_filenames.return_value = job_filenames - manager = JobFilesManager(scheduler=mock_scheduler) - await manager.copy_from_staging(1) - - 
mock_downloader.assert_called_once_with( - output_formats=job.output_formats, - output_filenames=job_filenames, - staging_paths=staging_paths, - output_dir=output_dir, - redownload=False, - include_staging_files=None, - ) + with patch("jupyter_scheduler.scheduler.Scheduler") as mock_scheduler: + # mock the Dask client attribute read by JobFilesManager + mock_dask_client = MagicMock() + mock_scheduler.dask_client = mock_dask_client + mock_scheduler.get_job.return_value = job + mock_scheduler.get_staging_paths.return_value = staging_paths + mock_scheduler.get_local_output_path.return_value = output_dir + mock_scheduler.get_job_filenames.return_value = job_filenames + manager = JobFilesManager(scheduler=mock_scheduler) + await manager.copy_from_staging(1) + + mock_downloader.assert_called_once_with( + output_formats=job.output_formats, + output_filenames=job_filenames, + staging_paths=staging_paths, + output_dir=output_dir, + redownload=False, + include_staging_files=None, + ) @pytest.fixture diff --git a/pyproject.toml b/pyproject.toml index 2ae7b947..f5def93c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "pydantic>=1.10,<3", "sqlalchemy>=2.0,<3", "croniter~=1.4", + "dask[distributed]", "pytz==2023.3", "fsspec==2023.6.0", "psutil~=5.9" diff --git a/ui-tests/tests/jupyter_scheduler.spec.ts-snapshots/list-view-linux.png b/ui-tests/tests/jupyter_scheduler.spec.ts-snapshots/list-view-linux.png index bd7180c3..d3b01a4a 100644 Binary files a/ui-tests/tests/jupyter_scheduler.spec.ts-snapshots/list-view-linux.png and b/ui-tests/tests/jupyter_scheduler.spec.ts-snapshots/list-view-linux.png differ