diff --git a/cads_worker/worker.py b/cads_worker/worker.py
index 57cf4fd..427795e 100644
--- a/cads_worker/worker.py
+++ b/cads_worker/worker.py
@@ -1,11 +1,13 @@
+import contextlib
 import functools
 import os
 import socket
-import tempfile
+from collections.abc import Iterator
 from typing import Any
 
 import cacholote
 import cads_broker.database
+import cads_broker.utils
 import distributed.worker
 import structlog
 from distributed import get_worker
@@ -36,6 +38,39 @@ def wrapper(self, *args, session=None, **kwargs):
     return wrapper
 
 
+@contextlib.contextmanager
+def enter_task_temp_dir() -> Iterator[str]:
+    """Change into a per-task working directory for the enclosed block.
+
+    The directory path comes from ``cads_broker.utils.rm_task_path`` for the
+    current distributed worker and task key; it is created and made the
+    process's current working directory. On exit the previous working
+    directory is restored and ``rm_task_path`` is called again for the task.
+
+    NOTE(review): ``rm_task_path`` presumably removes any existing directory
+    for the task and returns its path -- confirm against cads_broker.utils.
+
+    Yields the working directory in effect inside the ``with`` block.
+    """
+    old_cwd = os.getcwd()
+
+    worker = get_worker()
+    key = worker.get_current_task()
+    task_path = cads_broker.utils.rm_task_path(worker, key)
+
+    try:
+        task_path.mkdir(parents=True)
+        os.chdir(task_path)
+        yield os.getcwd()
+    finally:
+        # Nested so the task directory is still cleaned up when restoring the
+        # previous working directory raises (e.g. it no longer exists).
+        try:
+            os.chdir(old_cwd)
+        finally:
+            cads_broker.utils.rm_task_path(worker, key)
+
+
 class Context(cacholote.config.Context):
     def __init__(
         self,
@@ -169,7 +204,7 @@ def submit_workflow(
     config: dict[str, Any] = {},
     form: dict[str, Any] = {},
     metadata: dict[str, Any] = {},
-):
+) -> None:
     import cads_adaptors
 
     job_id = distributed.worker.thread_state.key  # type: ignore
@@ -185,7 +220,9 @@ def submit_workflow(
             message=socket.gethostname(),
             session=session,
         )
-        system_request = cads_broker.database.get_request(request_uid=job_id, session=session)
+        system_request = cads_broker.database.get_request(
+            request_uid=job_id, session=session
+        )
         request = system_request.request_body.get("request", {})
         form = system_request.adaptor_properties.form
         config.update(system_request.adaptor_properties.config)
@@ -201,23 +238,20 @@ def submit_workflow(
     )
     adaptor_class = cads_adaptors.get_adaptor_class(entry_point, setup_code)
     adaptor = adaptor_class(form=form, context=context, **config)
-    cwd = os.getcwd()
-    with tempfile.TemporaryDirectory() as tmpdir:
-        os.chdir(tmpdir)
-        try:
-            request = {k: request[k] for k in sorted(request.keys())}
-            result = cacholote.cacheable(adaptor.retrieve)(request=request)
-        except Exception as err:
-            logger.exception(job_id=job_id, event_type="EXCEPTION")
-            context.add_user_visible_error(f"The job failed with: {err.__class__.__name__}")
-            context.error(f"{err.__class__.__name__}: {str(err)}")
-            raise
-        finally:
-            os.chdir(cwd)
+    try:
+        sorted_request = {k: request[k] for k in sorted(request.keys())}
+        with enter_task_temp_dir():
+            result = cacholote.cacheable(adaptor.retrieve)(request=sorted_request)
+    except Exception as err:
+        logger.exception(job_id=job_id, event_type="EXCEPTION")
+        context.add_user_visible_error(f"The job failed with: {err.__class__.__name__}")
+        context.error(f"{err.__class__.__name__}: {str(err)}")
+        raise
+
     fs, _ = cacholote.utils.get_cache_files_fs_dirname()
     fs.chmod(result.result["args"][0]["file:local_path"], acl="public-read")
     with context.session_maker() as session:
-        request = cads_broker.database.set_request_cache_id(
+        cads_broker.database.set_request_cache_id(
             request_uid=job_id,
             cache_id=result.id,
             session=session,