From 593888aaf194c41b98ec6284cebb3e0bf1b672a2 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Mon, 6 Oct 2025 22:09:38 +0200 Subject: [PATCH] initial prototype --- .../aiida_core/engine/__init__.py | 0 .../aiida_core/engine/calcjob/__init__.py | 0 .../aiida_core/engine/calcjob/tasks.py | 767 ++++++++++++++++++ .../aiida_core/transport.py | 17 +- .../example_dags/arithmetic_aiida_native.py | 411 ++++++++++ .../operators/async_aiida_calcjob.py | 326 ++++++++ .../taskgroups/async_aiida_calcjob.py | 465 +++++++++++ .../triggers/async_aiida_calcjob.py | 423 ++++++++++ 8 files changed, 2404 insertions(+), 5 deletions(-) create mode 100644 src/airflow_provider_aiida/aiida_core/engine/__init__.py create mode 100644 src/airflow_provider_aiida/aiida_core/engine/calcjob/__init__.py create mode 100644 src/airflow_provider_aiida/aiida_core/engine/calcjob/tasks.py create mode 100644 src/airflow_provider_aiida/example_dags/arithmetic_aiida_native.py create mode 100644 src/airflow_provider_aiida/operators/async_aiida_calcjob.py create mode 100644 src/airflow_provider_aiida/taskgroups/async_aiida_calcjob.py create mode 100644 src/airflow_provider_aiida/triggers/async_aiida_calcjob.py diff --git a/src/airflow_provider_aiida/aiida_core/engine/__init__.py b/src/airflow_provider_aiida/aiida_core/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/airflow_provider_aiida/aiida_core/engine/calcjob/__init__.py b/src/airflow_provider_aiida/aiida_core/engine/calcjob/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/airflow_provider_aiida/aiida_core/engine/calcjob/tasks.py b/src/airflow_provider_aiida/aiida_core/engine/calcjob/tasks.py new file mode 100644 index 0000000..d7d5777 --- /dev/null +++ b/src/airflow_provider_aiida/aiida_core/engine/calcjob/tasks.py @@ -0,0 +1,767 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Transport tasks for calculation jobs.""" + +from __future__ import annotations + +import asyncio +import functools +import logging +import tempfile +from typing import TYPE_CHECKING, Any, Callable, Optional + +import plumpy +import plumpy.futures +import plumpy.persistence +import plumpy.process_states + +from aiida.common.datastructures import CalcJobState +from aiida.common.exceptions import FeatureNotAvailable, TransportTaskException +from aiida.common.folders import SandboxFolder +from aiida.common.links import LinkType +from aiida.engine import utils +from aiida.engine.daemon import execmanager +from aiida.engine.processes.exit_code import ExitCode +from aiida.engine.transports import TransportQueue +from aiida.engine.utils import InterruptableFuture, interruptable_task +from aiida.manage.configuration import get_config_option +from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode +from aiida.schedulers.datastructures import JobState + +from aiida.engine.processes.process import ProcessState +from aiida.engine.processes.calcjobs.monitors import CalcJobMonitorAction, CalcJobMonitorResult, CalcJobMonitors + +if TYPE_CHECKING: + from aiida.engine.processes.calcjobs.calcjob import CalcJob + +UPLOAD_COMMAND = 'upload' +SUBMIT_COMMAND = 'submit' +UPDATE_COMMAND = 'update' +RETRIEVE_COMMAND = 'retrieve' +STASH_COMMAND = 'stash' +KILL_COMMAND = 'kill' + +RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval' +MAX_ATTEMPTS_OPTION = 'transport.task_maximum_attempts' + +logger = logging.getLogger(__name__) + + +class PreSubmitException(Exception): # noqa: N818 + """Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`.""" + + +async def task_upload_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture, calc_info=None): + """Transport task that will attempt to upload the files of a job calculation to the remote. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
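+    The initial retry interval and the maximum number of attempts are read from the
+    ``transport.task_retry_initial_interval`` and ``transport.task_maximum_attempts`` config options
+    (``RETRY_INTERVAL_OPTION`` and ``MAX_ATTEMPTS_OPTION`` defined above).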
+    If all retries fail, the task will raise a TransportTaskException.
+
+    :param node: the CalcJobNode
+    :param transport_queue: the TransportQueue from which to request a Transport (unused, replaced with get_transport_queue)
+    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
+    :param calc_info: optional CalcInfo from prepare_for_submission
+
+    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
+    """
+
+    if node.get_state() == CalcJobState.UNSTASHING:
+        logger.warning(f'CalcJob<{node.pk}> already marked as UNSTASHING, skipping task_upload_job')
+        return
+
+    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
+    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
+    filepath_sandbox = get_config_option('storage.sandbox') or None
+
+    authinfo = node.get_authinfo()
+
+    async def do_upload():
+        # AIRFLOW-AIIDA-MODIFICATION: Get transport queue from provider package instead of parameter
+        # This ensures we use the current event loop instead of a stale one
+        from airflow_provider_aiida.aiida_core.transport import get_transport_queue
+        transport_queue = get_transport_queue()
+
+        with transport_queue.request_transport(authinfo) as request:
+            transport = await cancellable.with_interrupt(request)
+
+            with SandboxFolder(filepath_sandbox) as folder:
+                # Restore folder contents from node's repository
+                # (they were stored by prepare_for_submission task)
+                node.base.repository.copy_tree(folder.abspath, path=None)
+
+                # Use the provided calc_info or create a minimal one
+                if calc_info is None:
+                    from aiida.common.datastructures import CalcInfo
+                    calc_info_to_use = CalcInfo()
+                    calc_info_to_use.uuid = str(node.uuid)
+                    calc_info_to_use.skip_submit = False
+                else:
+                    calc_info_to_use = calc_info
+
+                remote_folder = await execmanager.upload_calculation(node, transport, calc_info_to_use, folder)
+                # AIRFLOW-AIIDA-MODIFICATION: Removed process.out('remote_folder', remote_folder)
+                # Original AiiDA code stores remote_folder as an output link via process.out()
+                # which is allowed on stored nodes. We tried to store it as an attribute instead,
+                # but attributes cannot be modified on stored+unsealed nodes (only _updatable_attributes can).
+                # Since we don't have process context to create output links, we skip storing remote_folder.
+                # The remote working directory path is already tracked in the node's internal state.
+                skip_submit = calc_info_to_use.skip_submit or False
+
+                return skip_submit
+
+    try:
+        logger.info(f'scheduled request to upload CalcJob<{node.pk}>')
+        ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)
+        skip_submit = await utils.exponential_backoff_retry(
+            do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
+        )
+    except PreSubmitException:
+        raise
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
+        raise
+    except Exception as exception:
+        logger.warning(f'uploading CalcJob<{node.pk}> failed')
+        raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') from exception
+    else:
+        logger.info(f'uploading CalcJob<{node.pk}> successful')
+        node.set_state(CalcJobState.UNSTASHING)
+        return skip_submit
+
+
+async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
+    """Transport task that will attempt to submit a job calculation.
+ + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + if node.get_state() == CalcJobState.WITHSCHEDULER: + assert node.get_job_id() is not None, 'job is WITHSCHEDULER, however, it does not have a job id' + logger.warning(f'CalcJob<{node.pk}> already marked as WITHSCHEDULER, skipping task_submit_job') + return node.get_job_id() + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + + async def do_submit(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + return execmanager.submit_calculation(node, transport) + + try: + logger.info(f'scheduled request to submit CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + result = await utils.exponential_backoff_retry( + do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'submitting CalcJob<{node.pk}> failed') + raise TransportTaskException(f'submit_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'submitting CalcJob<{node.pk}> successful') + node.set_state(CalcJobState.WITHSCHEDULER) + return result + + +async def task_update_job(node: CalcJobNode, job_manager, cancellable: InterruptableFuture): + """Transport task that will attempt to update the scheduler status of the job calculation. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
+ If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param job_manager: The job manager + :param cancellable: A cancel flag + :return: True if the tasks was successfully completed, False otherwise + """ + state = node.get_state() + + if state in [CalcJobState.RETRIEVING, CalcJobState.STASHING]: + logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_update_job') + return True + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + job_id = node.get_job_id() + + async def do_update(): + # Get the update request + with job_manager.request_job_info_update(authinfo, job_id) as update_request: + job_info = await cancellable.with_interrupt(update_request) + + if job_info is None: + # If the job is computed or not found assume it's done + node.set_scheduler_state(JobState.DONE) + job_done = True + else: + node.set_last_job_info(job_info) + node.set_scheduler_state(job_info.job_state) + job_done = job_info.job_state == JobState.DONE + + return job_done + + try: + logger.info(f'scheduled request to update CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + job_done = await utils.exponential_backoff_retry( + do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'updating CalcJob<{node.pk}> failed') + raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'updating CalcJob<{node.pk}> successful') + if job_done: + node.set_state(CalcJobState.STASHING) + + return job_done + + +async def task_monitor_job( + node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture, monitors: CalcJobMonitors +): + """Transport task that will monitor the job calculation if any monitors have been defined. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: A cancel flag + :param monitors: An instance of ``CalcJobMonitors`` holding the collection of monitors to process. 
+ :return: True if the tasks was successfully completed, False otherwise + """ + state = node.get_state() + + if state in [CalcJobState.RETRIEVING, CalcJobState.STASHING]: + logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_monitor_job') + return None + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + authinfo = node.get_authinfo() + + async def do_monitor(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + return monitors.process(node, transport) + + try: + logger.info(f'scheduled request to monitor CalcJob<{node.pk}>') + ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) + monitor_result = await utils.exponential_backoff_retry( + do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions + ) + except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): + raise + except Exception as exception: + logger.warning(f'monitoring CalcJob<{node.pk}> failed') + raise TransportTaskException(f'monitor_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'monitoring CalcJob<{node.pk}> successful') + return monitor_result + + +async def task_retrieve_job( + node: CalcJobNode, + transport_queue: TransportQueue, + retrieved_temporary_folder: str, + cancellable: InterruptableFuture, +): + """Transport task that will attempt to retrieve all files of a completed job calculation. + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException + :param node: the CalcJobNode + :param transport_queue: the TransportQueue from which to request a Transport + :param retrieved_temporary_folder: the absolute path to a directory to store files + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + if node.get_state() == CalcJobState.PARSING: + logger.warning(f'CalcJob<{node.pk}> already marked as PARSING, skipping task_retrieve_job') + return + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + authinfo = node.get_authinfo() + + async def do_retrieve(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this + # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the + # accounting was called but could not be set. 
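+            # NOTE: `get_detailed_job_info` is optional for scheduler plugins; schedulers that do not implement it
+            # raise `FeatureNotAvailable`, which is caught below and recorded as `None` on the node.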
+            scheduler = node.computer.get_scheduler()  # type: ignore[union-attr]
+            scheduler.set_transport(transport)
+
+            if node.get_job_id() is None:
+                logger.warning(f'there is no job id for CalcJobNode<{node.pk}>: skipping `get_detailed_job_info`')
+                retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
+            else:
+                try:
+                    detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id())
+                except FeatureNotAvailable:
+                    logger.info(f'detailed job info not available for scheduler of CalcJob<{node.pk}>')
+                    node.set_detailed_job_info(None)
+                else:
+                    node.set_detailed_job_info(detailed_job_info)
+
+                retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
+
+            # AIRFLOW-AIIDA-MODIFICATION: Manually create output link instead of process.out()
+            # Original AiiDA code: process.out(node.link_label_retrieved, retrieved)
+            # Since we don't have process context, we manually create the output link
+            if retrieved is not None:
+                retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label=node.link_label_retrieved)
+                retrieved.store()
+
+            return retrieved
+
+    try:
+        logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
+        ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
+        result = await utils.exponential_backoff_retry(
+            do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
+        )
+    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
+        raise
+    except Exception as exception:
+        logger.warning(f'retrieving CalcJob<{node.pk}> failed')
+        raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') from exception
+    else:
+        node.set_state(CalcJobState.PARSING)
+        logger.info(f'retrieving CalcJob<{node.pk}> successful')
+        return result
+
+
+async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
+    """Transport task that will optionally stash files of a completed job calculation on the remote.
+
+    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
+    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
+    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
+    If all retries fail, the task will raise a TransportTaskException.
+
+    :param node: the node that represents the job calculation
+    :param transport_queue: the TransportQueue from which to request a Transport
+    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
+    :return: ``None`` if the task was successfully completed
+    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
+    """
+    if node.get_state() == CalcJobState.RETRIEVING:
+        logger.warning(f'calculation<{node.pk}> already marked as RETRIEVING, skipping task_stash_job')
+        return
+
+    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
+    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
+
+    authinfo = node.get_authinfo()
+
+    async def do_stash():
+        with transport_queue.request_transport(authinfo) as request:
+            transport = await cancellable.with_interrupt(request)
+
+            logger.info(f'stashing calculation<{node.pk}>')
+            return await execmanager.stash_calculation(node, transport)
+
+    try:
+        await utils.exponential_backoff_retry(
+            do_stash,
+            initial_interval,
+            max_attempts,
+            logger=node.logger,
+            ignore_exceptions=plumpy.process_states.Interruption,
+        )
+    except plumpy.process_states.Interruption:
+        raise
+    except Exception as exception:
+        logger.warning(f'stashing calculation<{node.pk}> failed')
+        raise TransportTaskException(f'stash_calculation failed {max_attempts} times consecutively') from exception
+    else:
+        node.set_state(CalcJobState.RETRIEVING)
+        logger.info(f'stashing calculation<{node.pk}> successful')
+        return
+
+
+async def task_unstash_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
+    """Transport task that will optionally unstash files on the remote before the job calculation is submitted."""
+    if node.get_state() == CalcJobState.SUBMITTING:
+        logger.warning(f'CalcJob<{node.pk}> already marked as SUBMITTING, skipping task_unstash_job')
+        return
+
+    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
+    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
+
+    authinfo = node.get_authinfo()
+
+    async def do_unstash():
+        with transport_queue.request_transport(authinfo) as request:
+            transport = await cancellable.with_interrupt(request)
+
+            logger.info(f'unstashing calculation<{node.pk}>')
+            return await execmanager.unstash_calculation(node, transport)
+
+    try:
+        await utils.exponential_backoff_retry(
+            do_unstash,
+            initial_interval,
+            max_attempts,
+            logger=node.logger,
+            ignore_exceptions=plumpy.process_states.Interruption,
+        )
+    except plumpy.process_states.Interruption:
+        raise
+    except Exception as exception:
+        logger.warning(f'unstashing calculation<{node.pk}> failed')
+        raise TransportTaskException(f'unstash_calculation failed {max_attempts} times consecutively') from exception
+    else:
+        node.set_state(CalcJobState.SUBMITTING)
+        logger.info(f'unstashing calculation<{node.pk}> successful')
+        return
+
+
+async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
+    """Transport task that will attempt to kill a job calculation.
+
+    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
+    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
+    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
+ If all retries fail, the task will raise a TransportTaskException + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + if node.get_state() in [CalcJobState.UPLOADING, CalcJobState.SUBMITTING]: + logger.warning(f'CalcJob<{node.pk}> killed, it was in the {node.get_state()} state') + return True + + authinfo = node.get_authinfo() + + async def do_kill(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + return execmanager.kill_calculation(node, transport) + + try: + logger.info(f'scheduled request to kill CalcJob<{node.pk}>') + result = await utils.exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) + except plumpy.process_states.Interruption: + raise + except Exception as exception: + logger.warning(f'killing CalcJob<{node.pk}> failed') + raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception + else: + logger.info(f'killing CalcJob<{node.pk}> successful') + node.set_scheduler_state(JobState.DONE) + return result + + +@plumpy.persistence.auto_persist('msg', 'data', '_command', '_monitor_result') +class Waiting(plumpy.process_states.Waiting): + """The waiting state for the `CalcJob` process.""" + + def __init__( + self, + process: 'CalcJob', + done_callback: Optional[Callable[..., Any]], + msg: Optional[str] = None, + data: Optional[Any] = None, + ): + """:param process: The process this state belongs to""" + super().__init__(process, done_callback, msg, data) + self._task: InterruptableFuture | None = None + self._killing: plumpy.futures.Future | None = None + self._command: Callable[..., Any] | None = None + self._monitor_result: CalcJobMonitorResult | None = None + self._monitors: CalcJobMonitors | None = None + + if isinstance(self.data, dict): + self._command = self.data['command'] + self._monitor_result = self.data.get('monitor_result', None) + else: + self._command = self.data + + @property + def monitors(self) -> CalcJobMonitors | None: + """Return the collection of monitors if specified in the inputs. + + :return: Instance of ``CalcJobMonitors`` containing monitors if specified in the process' input. + """ + if not hasattr(self, '_monitors'): + self._monitors = None + + if self._monitors is None and 'monitors' in self.process.node.inputs: + self._monitors = CalcJobMonitors(self.process.node.inputs.monitors) + + return self._monitors + + @property + def process(self) -> 'CalcJob': + """:return: The process""" + return self.state_machine # type: ignore[return-value] + + def load_instance_state(self, saved_state, load_context): + super().load_instance_state(saved_state, load_context) + self._task = None + self._killing = None + + async def execute(self) -> plumpy.process_states.State: # type: ignore[override] + """Override the execute coroutine of the base `Waiting` state. + Using the plumpy state machine the waiting state is repeatedly re-entered with different commands. 
+        The waiting state is not always the same instance; it could be re-instantiated when re-entering this method,
+        therefore any newly created attribute in each command block
+        (e.g. `SUBMIT_COMMAND`, `UPLOAD_COMMAND`, etc.) will be lost and is not usable in other blocks.
+        The advantage of this design is that the sequence is interruptible,
+        meaning the process can potentially come back and start from where it left off.
+
+        The overall sequence is as follows:
+        in case `skip_submit` is True:
+
+        UPLOAD -> STASH -> RETRIEVE
+          |  ^     |  ^      |  ^
+          v  |     v  |      v  |
+          .. ..    .. ..     .. ..
+
+        otherwise:
+
+        UPLOAD -> SUBMIT -> UPDATE -> STASH -> RETRIEVE
+          |  ^     |  ^      |  ^      |  ^      |  ^
+          v  |     v  |      v  |      v  |      v  |
+          .. ..    .. ..     .. ..     .. ..     .. ..
+        """
+
+        node = self.process.node
+        transport_queue = self.process.runner.transport
+        result: plumpy.process_states.State = self
+
+        process_status = f'Waiting for transport task: {self._command}'
+        node.set_process_status(process_status)
+
+        try:
+            if self._command == UPLOAD_COMMAND:
+                skip_submit = await self._launch_task(task_upload_job, self.process, transport_queue)
+                # Note: we do both `task_upload_job` and `task_unstash_job` at the same time,
+                # only because `skip_submit` is not easily accessible outside this `if` block!
+                if node.get_option('unstash') and node.process_type == 'aiida.calculations:core.unstash':
+                    await self._launch_task(task_unstash_job, node, transport_queue)
+                if skip_submit:
+                    result = self.stash(monitor_result=self._monitor_result)
+                else:
+                    result = self.submit()
+
+            elif self._command == SUBMIT_COMMAND:
+                result = await self._launch_task(task_submit_job, node, transport_queue)
+
+                if isinstance(result, ExitCode):
+                    # The scheduler plugin returned an exit code from ``Scheduler.submit_job`` indicating the
+                    # job submission failed due to a non-transient problem and the job should be terminated.
+ return self.create_state(ProcessState.RUNNING, self.process.terminate, result) + + result = self.update() + + elif self._command == UPDATE_COMMAND: + job_done = False + + while not job_done: + scheduler_state = node.get_scheduler_state() + scheduler_state_string = scheduler_state.name if scheduler_state else 'UNKNOWN' + process_status = f'Monitoring scheduler: job state {scheduler_state_string}' + node.set_process_status(process_status) + job_done = await self._launch_task(task_update_job, node, self.process.runner.job_manager) + monitor_result = await self._monitor_job(node, transport_queue, self.monitors) + + if monitor_result and monitor_result.action is CalcJobMonitorAction.KILL: + await self._kill_job(node, transport_queue) + job_done = True + + if monitor_result and not monitor_result.retrieve: + exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=monitor_result.message) + return self.create_state(ProcessState.RUNNING, self.process.terminate, exit_code) # type: ignore[return-value] + + result = self.stash(monitor_result=monitor_result) + + elif self._command == STASH_COMMAND: + if node.get_option('stash'): + await self._launch_task(task_stash_job, node, transport_queue) + result = self.retrieve(monitor_result=self._monitor_result) + + elif self._command == RETRIEVE_COMMAND: + temp_folder = tempfile.mkdtemp() + await self._launch_task(task_retrieve_job, self.process, transport_queue, temp_folder) + + if not self._monitor_result: + result = self.parse(temp_folder) + + elif self._monitor_result.parse is False: + exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=self._monitor_result.message) + result = self.create_state( # type: ignore[assignment] + ProcessState.RUNNING, self.process.terminate, exit_code + ) + + elif self._monitor_result.override_exit_code: + exit_code = self.process.exit_codes.STOPPED_BY_MONITOR.format(message=self._monitor_result.message) + result = self.parse(temp_folder, exit_code) + else: + result = self.parse(temp_folder) + + else: + raise RuntimeError('Unknown waiting command') + + except TransportTaskException as exception: + raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}') + except plumpy.process_states.KillInterruption as exception: + node.set_process_status(str(exception)) + return self.retrieve(monitor_result=self._monitor_result) + except (plumpy.futures.CancelledError, asyncio.CancelledError): + node.set_process_status(f'Transport task {self._command} was cancelled') + raise + except plumpy.process_states.Interruption: + node.set_process_status(f'Transport task {self._command} was interrupted') + raise + else: + node.set_process_status(None) + return result + finally: + # If we were trying to kill but we didn't deal with it, make sure it's set here + if self._killing and not self._killing.done(): + self._killing.set_result(False) + + async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorResult | None: + """Process job monitors if any were specified as inputs.""" + if monitors is None: + return None + + if self._monitor_result and self._monitor_result.action == CalcJobMonitorAction.DISABLE_ALL: + return None + + monitor_result = await self._launch_task(task_monitor_job, node, transport_queue, monitors=monitors) + + if monitor_result and monitor_result.outputs: + for label, output in monitor_result.outputs.items(): + self.process.out(label, output) + self.process.update_outputs() + + if monitor_result and monitor_result.action == 
CalcJobMonitorAction.DISABLE_SELF: + monitors.monitors[monitor_result.key].disabled = True + + if monitor_result is not None: + self._monitor_result = monitor_result + + return monitor_result + + async def _kill_job(self, node, transport_queue) -> None: + """Kill the job.""" + await self._launch_task(task_kill_job, node, transport_queue) + if self._killing is not None: + self._killing.set_result(True) + else: + logger.info(f'killed CalcJob<{node.pk}> but async future was None') + + async def _launch_task(self, coro, *args, **kwargs): + """Launch a coroutine as a task, making sure to make it interruptable.""" + task_fn = functools.partial(coro, *args, **kwargs) + try: + self._task = interruptable_task(task_fn) + result = await self._task + return result + finally: + self._task = None + + def upload(self) -> 'Waiting': + """Return the `Waiting` state that will `upload` the `CalcJob`.""" + msg = 'Waiting for calculation folder upload' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': UPLOAD_COMMAND} + ) + + def submit(self) -> 'Waiting': + """Return the `Waiting` state that will `submit` the `CalcJob`.""" + msg = 'Waiting for scheduler submission' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': SUBMIT_COMMAND} + ) + + def update(self) -> 'Waiting': + """Return the `Waiting` state that will `update` the `CalcJob`.""" + msg = 'Waiting for scheduler update' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': UPDATE_COMMAND} + ) + + def stash(self, monitor_result: CalcJobMonitorResult | None = None) -> 'Waiting': + """Return the `Waiting` state that will `stash` the `CalcJob`.""" + msg = 'Waiting to stash' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': STASH_COMMAND, 'monitor_result': monitor_result} + ) + + # def unstash(self, monitor_result: CalcJobMonitorResult | None = None) -> 'Waiting': + # """Return the `Waiting` state that will `unstash` the `CalcJob`.""" + # msg = 'Waiting to unstash' + # return self.create_state( # type: ignore[return-value] + # ProcessState.WAITING, None, msg=msg, data={'command': UNSTASH_COMMAND, 'monitor_result': monitor_result} + # ) + + def retrieve(self, monitor_result: CalcJobMonitorResult | None = None) -> 'Waiting': + """Return the `Waiting` state that will `retrieve` the `CalcJob`.""" + msg = 'Waiting to retrieve' + return self.create_state( # type: ignore[return-value] + ProcessState.WAITING, None, msg=msg, data={'command': RETRIEVE_COMMAND, 'monitor_result': monitor_result} + ) + + def parse( + self, retrieved_temporary_folder: str, exit_code: ExitCode | None = None + ) -> plumpy.process_states.Running: + """Return the `Running` state that will parse the `CalcJob`. + + :param retrieved_temporary_folder: temporary folder used in retrieving that can be used during parsing. 
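+        :param exit_code: if provided, this exit code is used instead of the one determined during parsing
+            (e.g. an exit code set by a monitor that overrides the parser result).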
+ """ + return self.create_state( # type: ignore[return-value] + ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder, exit_code + ) + + def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ignore[override] + """Interrupt the `Waiting` state by calling interrupt on the transport task `InterruptableFuture`.""" + if self._task is not None: + self._task.interrupt(reason) + + if isinstance(reason, plumpy.process_states.KillInterruption): + if self._killing is None: + self._killing = plumpy.futures.Future() + return self._killing + + return None diff --git a/src/airflow_provider_aiida/aiida_core/transport.py b/src/airflow_provider_aiida/aiida_core/transport.py index fb9f5df..59ae4a3 100644 --- a/src/airflow_provider_aiida/aiida_core/transport.py +++ b/src/airflow_provider_aiida/aiida_core/transport.py @@ -105,11 +105,18 @@ def get_authinfo_from_airflow_connection(conn_id: str): def get_transport_queue() -> TransportQueue: - """Return a per-process shared TransportQueue instance.""" - global _TRANSPORT_QUEUE - if _TRANSPORT_QUEUE is None: - _TRANSPORT_QUEUE = TransportQueue() - return _TRANSPORT_QUEUE + """Return a TransportQueue instance using the current event loop. + + Note: Always creates a new TransportQueue to ensure it uses the current + event loop. This is necessary because Airflow triggers run in different + async contexts with different event loops. + """ + import asyncio + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + return TransportQueue(loop=loop) def get_authinfo_cached(conn_id: str): diff --git a/src/airflow_provider_aiida/example_dags/arithmetic_aiida_native.py b/src/airflow_provider_aiida/example_dags/arithmetic_aiida_native.py new file mode 100644 index 0000000..930df08 --- /dev/null +++ b/src/airflow_provider_aiida/example_dags/arithmetic_aiida_native.py @@ -0,0 +1,411 @@ +""" +Example DAG using AsyncAiiDACalcJobTaskGroup with Native AiiDA Task Functions + +This demonstrates native AiiDA CalcJob execution in Airflow using the task functions +from aiida-core's calcjobs.tasks module through async triggers. + +NOTE: This is a simplified example that demonstrates the structure. In a real implementation, +you would need proper AiiDA initialization, database setup, and computer configuration. +""" + +from datetime import datetime +from pathlib import Path +from typing import Any + +from airflow import DAG +from airflow.decorators import task +from airflow.models.param import Param + +from aiida.engine import CalcJobProcessSpec +from aiida.orm import Int, Str + +from airflow_provider_aiida.taskgroups.async_aiida_calcjob import SimpleAsyncAiiDACalcJobTaskGroup + + +class AiiDAAddJobTaskGroup(SimpleAsyncAiiDACalcJobTaskGroup): + """Addition job using native AiiDA CalcJob task functions.""" + + @classmethod + def define(cls, spec: CalcJobProcessSpec): + """Define the input specification using AiiDA's spec system.""" + super().define(spec) + spec.input('x', valid_type=Int, help='First operand') + spec.input('y', valid_type=Int, help='Second operand') + spec.input('computer', valid_type=Str, help='Computer label') + spec.input('code', valid_type=Str, required=False, help='Code label') + + def prepare_for_submission(self, folder): + """Prepare the calculation for submission. + + Write input files and return CalcInfo with job details. 
+ """ + from aiida.common.datastructures import CalcInfo, CodeInfo + + # Get input values from self.resolved_inputs (similar to AiiDA's self.inputs) + x = self.resolved_inputs.x + y = self.resolved_inputs.y + + # Write a simple bash script that does addition + with folder.open('add.sh', 'w') as f: + f.write('#!/bin/bash\n') + f.write(f'# Addition script: {x} + {y}\n') + f.write(f'echo $(({x} + {y})) > result.out\n') + + # Make script executable + import os + os.chmod(folder.get_abs_path('add.sh'), 0o755) + + # Create CalcInfo + calc_info = CalcInfo() + + # Check if code input is provided + if hasattr(self.resolved_inputs, 'code') and self.resolved_inputs.code: + from aiida.orm import load_code + code = load_code(self.resolved_inputs.code) + + code_info = CodeInfo() + code_info.code_uuid = code.uuid + # Use bash -c to execute the script with proper quoting + code_info.cmdline_params = ['-c', 'bash add.sh'] + code_info.stdout_name = 'add.stdout' + code_info.stderr_name = 'add.stderr' + calc_info.codes_info = [code_info] + else: + # No code provided - skip code execution for this simple example + calc_info.codes_info = [] + + # Files to retrieve after job completes + calc_info.retrieve_list = ['result.out', 'add.stdout', 'add.stderr'] + + return calc_info + + def parse(self, retrieved_temporary_folder: str, **context) -> dict[str, Any]: + """Parse the addition results from retrieved files. + + Args: + retrieved_temporary_folder: Path to folder containing retrieved files (not used, kept for compatibility) + + Returns: + dict containing exit_status, results, and node_pk + """ + from aiida.orm import load_node + + # Get the node_pk from prepare task + ti = context['task_instance'] + node_pk = ti.xcom_pull(task_ids=f'{self.group_id}.prepare_calcjob') + + # Load the node + node = load_node(node_pk) + + print(f"DEBUG: Node PK = {node_pk}") + print(f"DEBUG: Node state = {node.get_state()}") + print(f"DEBUG: Job ID = {node.get_job_id()}") + + # Parse retrieved files from node.outputs.retrieved + results = {} + exit_status = 0 + + try: + # Access retrieved files via AiiDA node outputs + if hasattr(node.outputs, 'retrieved'): + retrieved = node.outputs.retrieved + print(f"DEBUG: Retrieved folder found in node.outputs.retrieved") + print(f"DEBUG: Files in retrieved: {list(retrieved.base.repository.list_object_names())}") + + # Read result.out from retrieved folder + if 'result.out' in retrieved.base.repository.list_object_names(): + with retrieved.base.repository.open('result.out', 'r') as f: + result_content = f.read().strip() + result_value = int(result_content) + results['sum'] = result_value + + # Get input values from node attributes + x = node.base.attributes.get('x') + y = node.base.attributes.get('y') + print(f"Addition result ({x} + {y}): {result_value}") + + # Create output node and link it to the CalcJob + from aiida.orm import Int + from aiida.common.links import LinkType + output = Int(result_value) + output.base.links.add_incoming(node, LinkType.CREATE, link_label='sum') + output.store() + else: + print(f"ERROR: result.out not found in retrieved folder") + exit_status = 1 + else: + print(f"ERROR: No retrieved output found on node") + exit_status = 1 + + except Exception as e: + print(f"ERROR parsing results: {e}") + import traceback + traceback.print_exc() + exit_status = 2 + + return { + 'exit_status': exit_status, + 'results': results, + 'node_pk': node_pk + } + + +class AiiDAMultiplyJobTaskGroup(SimpleAsyncAiiDACalcJobTaskGroup): + """Multiplication job using native AiiDA CalcJob task 
functions.""" + + @classmethod + def define(cls, spec: CalcJobProcessSpec): + """Define the input specification using AiiDA's spec system.""" + super().define(spec) + spec.input('x', valid_type=Int, help='First operand') + spec.input('y', valid_type=Int, help='Second operand') + spec.input('computer', valid_type=Str, help='Computer label') + spec.input('code', valid_type=Str, required=False, help='Code label') + + def prepare_for_submission(self, folder): + """Prepare the calculation for submission. + + Write input files and return CalcInfo with job details. + """ + from aiida.common.datastructures import CalcInfo, CodeInfo + + # Get input values from self.resolved_inputs (similar to AiiDA's self.inputs) + x = self.resolved_inputs.x + y = self.resolved_inputs.y + + # Write a simple bash script that does multiplication + with folder.open('multiply.sh', 'w') as f: + f.write('#!/bin/bash\n') + f.write(f'# Multiplication script: {x} * {y}\n') + f.write(f'echo $(({x} * {y})) > multiply_result.out\n') + f.write(f'echo "Performed {x} * {y} = $(({x} * {y}))" > operation.log\n') + + # Make script executable + import os + os.chmod(folder.get_abs_path('multiply.sh'), 0o755) + + # Create CalcInfo + calc_info = CalcInfo() + + # Check if code input is provided + if hasattr(self.resolved_inputs, 'code') and self.resolved_inputs.code: + from aiida.orm import load_code + code = load_code(self.resolved_inputs.code) + + code_info = CodeInfo() + code_info.code_uuid = code.uuid + # Use bash -c to execute the script with proper quoting + code_info.cmdline_params = ['-c', 'bash multiply.sh'] + code_info.stdout_name = 'multiply.stdout' + code_info.stderr_name = 'multiply.stderr' + calc_info.codes_info = [code_info] + else: + # No code provided - skip code execution for this simple example + calc_info.codes_info = [] + + # Files to retrieve after job completes + calc_info.retrieve_list = ['multiply_result.out', 'operation.log', 'multiply.stdout', 'multiply.stderr'] + + return calc_info + + def parse(self, retrieved_temporary_folder: str, **context) -> dict[str, Any]: + """Parse the multiplication results from retrieved files.""" + from aiida.orm import load_node + + # Get the node_pk from prepare task + ti = context['task_instance'] + node_pk = ti.xcom_pull(task_ids=f'{self.group_id}.prepare_calcjob') + + # Load the node + node = load_node(node_pk) + + # Parse retrieved files from node.outputs.retrieved + results = {} + exit_status = 0 + + try: + # Access retrieved files via AiiDA node outputs + if hasattr(node.outputs, 'retrieved'): + retrieved = node.outputs.retrieved + print(f"DEBUG: Retrieved folder found in node.outputs.retrieved") + print(f"DEBUG: Files in retrieved: {list(retrieved.base.repository.list_object_names())}") + + # Read multiply_result.out from retrieved folder + if 'multiply_result.out' in retrieved.base.repository.list_object_names(): + with retrieved.base.repository.open('multiply_result.out', 'r') as f: + result_content = f.read().strip() + result_value = int(result_content) + results['product'] = result_value + + # Get input values from node attributes + x = node.base.attributes.get('x') + y = node.base.attributes.get('y') + print(f"Multiplication result ({x} * {y}): {result_value}") + + # Create output node and link it to the CalcJob + from aiida.orm import Int + from aiida.common.links import LinkType + output = Int(result_value) + output.base.links.add_incoming(node, LinkType.CREATE, link_label='product') + output.store() + else: + print(f"ERROR: multiply_result.out not found in 
retrieved folder") + exit_status = 1 + + # Read operation.log if it exists + if 'operation.log' in retrieved.base.repository.list_object_names(): + with retrieved.base.repository.open('operation.log', 'r') as f: + operation_log = f.read().strip() + results['operation_log'] = operation_log + print(f"Operation log: {operation_log}") + else: + print(f"ERROR: No retrieved output found on node") + exit_status = 1 + + except Exception as e: + print(f"ERROR parsing results: {e}") + import traceback + traceback.print_exc() + exit_status = 2 + + return { + 'exit_status': exit_status, + 'results': results, + 'node_pk': node_pk + } + + +# Create DAG +default_args = { + 'owner': 'airflow', + 'depends_on_past': False, + 'start_date': datetime(2025, 1, 1), + 'email_on_failure': False, + 'email_on_retry': False, + 'retries': 0, +} + +with DAG( + 'arithmetic_aiida_native', + default_args=default_args, + description='Native AiiDA CalcJob TaskGroup using aiida-core task functions', + schedule=None, + catchup=False, + tags=['aiida', 'arithmetic', 'calcjob', 'native', 'async'], + params={ + "computer": Param("localhost", type="string", description="AiiDA computer label"), + "code": Param("bash", type="string", description="AiiDA code label for bash execution"), + "add_x": Param(8, type="integer", description="First operand for addition"), + "add_y": Param(4, type="integer", description="Second operand for addition"), + "multiply_x": Param(6, type="integer", description="First operand for multiplication"), + "multiply_y": Param(9, type="integer", description="Second operand for multiplication"), + } +) as dag: + + # Create task groups using native AiiDA task functions + # These will execute via the async triggers in the Airflow triggerer + # Pass template strings - they'll be converted to AiiDA types at task execution time + add_job = AiiDAAddJobTaskGroup( + group_id="aiida_addition_job", + x="{{ params.add_x }}", + y="{{ params.add_y }}", + computer="{{ params.computer }}", + code="{{ params.code }}", + submit_script_filename="add.sh", + ) + + multiply_job = AiiDAMultiplyJobTaskGroup( + group_id="aiida_multiplication_job", + x="{{ params.multiply_x }}", + y="{{ params.multiply_y }}", + computer="{{ params.computer }}", + code="{{ params.code }}", + submit_script_filename="multiply.sh", + ) + + @task + def combine_results(**context): + """Combine results from both AiiDA CalcJobs. + + This demonstrates accessing results from AiiDA CalcJobs + that were executed via the async triggers. 
+ """ + task_instance = context['task_instance'] + + # Pull results from both parse tasks + add_result = task_instance.xcom_pull( + task_ids='aiida_addition_job.parse' + ) + multiply_result = task_instance.xcom_pull( + task_ids='aiida_multiplication_job.parse' + ) + + # Create combined result + combined = { + 'addition': { + 'exit_status': add_result['exit_status'], + 'success': add_result['exit_status'] == 0, + 'results': add_result['results'], + 'node_pk': add_result['node_pk'] + }, + 'multiplication': { + 'exit_status': multiply_result['exit_status'], + 'success': multiply_result['exit_status'] == 0, + 'results': multiply_result['results'], + 'node_pk': multiply_result['node_pk'] + }, + 'overall_success': ( + add_result['exit_status'] == 0 and + multiply_result['exit_status'] == 0 + ) + } + + print("=" * 60) + print("COMBINED AIIDA CALCJOB RESULTS") + print("=" * 60) + print(f"Addition CalcJob (Node {add_result['node_pk']}):") + print(f" Status: {'SUCCESS' if combined['addition']['success'] else 'FAILED'}") + print(f" Results: {add_result['results']}") + print() + print(f"Multiplication CalcJob (Node {multiply_result['node_pk']}):") + print(f" Status: {'SUCCESS' if combined['multiplication']['success'] else 'FAILED'}") + print(f" Results: {multiply_result['results']}") + print() + print(f"Overall: {'ALL SUCCEEDED' if combined['overall_success'] else 'SOME FAILED'}") + print("=" * 60) + + if not combined['overall_success']: + raise ValueError("One or more CalcJobs failed") + + return combined + + # Set up workflow: both jobs run in parallel, then combine results + combine_task = combine_results() + [add_job, multiply_job] >> combine_task + + +if __name__ == "__main__": + from aiida import load_profile + load_profile() + """Execute the DAG for testing/debugging.""" + from datetime import datetime + + print("=" * 60) + print("Testing arithmetic_aiida_native DAG") + print("=" * 60) + + # Test the DAG with default parameters + dag.test( + run_conf={ + "computer": "localhost", + "code": "bash", + "add_x": 8, + "add_y": 4, + "multiply_x": 6, + "multiply_y": 9, + } + ) + + print("\n" + "=" * 60) + print("DAG test completed!") + print("=" * 60) diff --git a/src/airflow_provider_aiida/operators/async_aiida_calcjob.py b/src/airflow_provider_aiida/operators/async_aiida_calcjob.py new file mode 100644 index 0000000..49f3358 --- /dev/null +++ b/src/airflow_provider_aiida/operators/async_aiida_calcjob.py @@ -0,0 +1,326 @@ +"""Airflow operators that defer to AiiDA CalcJob triggers. + +These operators provide async execution of AiiDA CalcJob transport tasks by deferring +to the corresponding triggers that wrap aiida-core's task functions. +""" + +from typing import Any + +from airflow.models import BaseOperator +from airflow.utils.context import Context + +from airflow_provider_aiida.triggers.async_aiida_calcjob import ( + AiiDAUploadTrigger, + AiiDASubmitTrigger, + AiiDAUpdateTrigger, + AiiDAMonitorTrigger, + AiiDARetrieveTrigger, + AiiDAStashTrigger, + AiiDAUnstashTrigger, + AiiDAKillTrigger, +) + + +class AiiDAAsyncUploadOperator(BaseOperator): + """Operator that defers to AiiDAUploadTrigger to upload CalcJob files. + + This operator executes the AiiDA task_upload_job function asynchronously. + """ + + template_fields = ["node_pk"] + + def __init__(self, node_pk: int, **kwargs): + """Initialize the upload operator. 
+ + :param node_pk: Primary key of the CalcJobNode to upload + """ + super().__init__(**kwargs) + self.node_pk = node_pk + + def execute(self, context: Context): + """Defer to the upload trigger.""" + self.defer( + trigger=AiiDAUploadTrigger(node_pk=self.node_pk), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict): + """Handle the trigger completion.""" + if event["status"] == "error": + error_msg = f"Upload failed: {event['message']}" + if "traceback" in event: + error_msg += f"\n\nFull traceback:\n{event['traceback']}" + raise ValueError(error_msg) + + skip_submit = event.get("skip_submit", False) + self.log.info(f"Upload completed successfully. Skip submit: {skip_submit}") + return {"skip_submit": skip_submit} + + +class AiiDAAsyncSubmitOperator(BaseOperator): + """Operator that defers to AiiDASubmitTrigger to submit a CalcJob. + + This operator executes the AiiDA task_submit_job function asynchronously. + """ + + template_fields = ["node_pk"] + + def __init__(self, node_pk: int, **kwargs): + """Initialize the submit operator. + + :param node_pk: Primary key of the CalcJobNode to submit + """ + super().__init__(**kwargs) + self.node_pk = node_pk + + def execute(self, context: Context): + """Defer to the submit trigger.""" + self.defer( + trigger=AiiDASubmitTrigger(node_pk=self.node_pk), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict): + """Handle the trigger completion.""" + if event["status"] == "error": + error_msg = f"Submit failed: {event['message']}" + if "traceback" in event: + error_msg += f"\n\nFull traceback:\n{event['traceback']}" + raise ValueError(error_msg) + + job_id = event.get("job_id") + self.log.info(f"Submit completed successfully. Job ID: {job_id}") + return {"job_id": job_id} + + +class AiiDAAsyncUpdateOperator(BaseOperator): + """Operator that defers to AiiDAUpdateTrigger to monitor CalcJob status. + + This operator executes the AiiDA task_update_job function asynchronously, + polling until the job is complete. + """ + + template_fields = ["node_pk", "sleep_interval"] + + def __init__(self, node_pk: int, sleep_interval: int = 5, **kwargs): + """Initialize the update operator. + + :param node_pk: Primary key of the CalcJobNode to update + :param sleep_interval: Seconds to sleep between update checks + """ + super().__init__(**kwargs) + self.node_pk = node_pk + self.sleep_interval = sleep_interval + + def execute(self, context: Context): + """Defer to the update trigger.""" + self.defer( + trigger=AiiDAUpdateTrigger( + node_pk=self.node_pk, + sleep_interval=self.sleep_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict): + """Handle the trigger completion.""" + if event["status"] == "error": + error_msg = f"Update failed: {event['message']}" + if "traceback" in event: + error_msg += f"\n\nFull traceback:\n{event['traceback']}" + raise ValueError(error_msg) + + job_done = event.get("job_done", False) + self.log.info(f"Update completed successfully. Job done: {job_done}") + return {"job_done": job_done} + + +class AiiDAAsyncMonitorOperator(BaseOperator): + """Operator that defers to AiiDAMonitorTrigger to monitor CalcJob. + + This operator executes the AiiDA task_monitor_job function asynchronously. + """ + + template_fields = ["node_pk", "monitors_pk"] + + def __init__(self, node_pk: int, monitors_pk: int | None = None, **kwargs): + """Initialize the monitor operator. 
+ + :param node_pk: Primary key of the CalcJobNode to monitor + :param monitors_pk: Primary key of the CalcJobMonitors node (if applicable) + """ + super().__init__(**kwargs) + self.node_pk = node_pk + self.monitors_pk = monitors_pk + + def execute(self, context: Context): + """Defer to the monitor trigger.""" + self.defer( + trigger=AiiDAMonitorTrigger( + node_pk=self.node_pk, + monitors_pk=self.monitors_pk, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict): + """Handle the trigger completion.""" + if event["status"] == "error": + error_msg = f"Monitor failed: {event['message']}" + if "traceback" in event: + error_msg += f"\n\nFull traceback:\n{event['traceback']}" + raise ValueError(error_msg) + + self.log.info("Monitor completed successfully") + # Return monitor result details + return { + "action": event.get("action"), + "message": event.get("message"), + "retrieve": event.get("retrieve"), + "parse": event.get("parse"), + } + + +class AiiDAAsyncRetrieveOperator(BaseOperator): + """Operator that defers to AiiDARetrieveTrigger to retrieve CalcJob files. + + This operator executes the AiiDA task_retrieve_job function asynchronously. + """ + + template_fields = ["node_pk", "retrieved_temporary_folder"] + + def __init__(self, node_pk: int, retrieved_temporary_folder: str, **kwargs): + """Initialize the retrieve operator. + + :param node_pk: Primary key of the CalcJobNode to retrieve + :param retrieved_temporary_folder: Path to temporary folder for retrieved files + """ + super().__init__(**kwargs) + self.node_pk = node_pk + self.retrieved_temporary_folder = retrieved_temporary_folder + + def execute(self, context: Context): + """Defer to the retrieve trigger.""" + self.defer( + trigger=AiiDARetrieveTrigger( + node_pk=self.node_pk, + retrieved_temporary_folder=self.retrieved_temporary_folder, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict): + """Handle the trigger completion.""" + if event["status"] == "error": + error_msg = f"Retrieve failed: {event['message']}" + if "traceback" in event: + error_msg += f"\n\nFull traceback:\n{event['traceback']}" + raise ValueError(error_msg) + + retrieved = event.get("retrieved", False) + self.log.info(f"Retrieve completed successfully. Retrieved: {retrieved}") + return {"retrieved": retrieved} + + +class AiiDAAsyncStashOperator(BaseOperator): + """Operator that defers to AiiDAStashTrigger to stash CalcJob files. + + This operator executes the AiiDA task_stash_job function asynchronously. + """ + + template_fields = ["node_pk"] + + def __init__(self, node_pk: int, **kwargs): + """Initialize the stash operator. + + :param node_pk: Primary key of the CalcJobNode to stash + """ + super().__init__(**kwargs) + self.node_pk = node_pk + + def execute(self, context: Context): + """Defer to the stash trigger.""" + self.defer( + trigger=AiiDAStashTrigger(node_pk=self.node_pk), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict): + """Handle the trigger completion.""" + if event["status"] == "error": + error_msg = f"Stash failed: {event['message']}" + if "traceback" in event: + error_msg += f"\n\nFull traceback:\n{event['traceback']}" + raise ValueError(error_msg) + + self.log.info("Stash completed successfully") + + +class AiiDAAsyncUnstashOperator(BaseOperator): + """Operator that defers to AiiDAUnstashTrigger to unstash CalcJob files. 
+
+    This operator executes the AiiDA task_unstash_job function asynchronously.
+    """
+
+    template_fields = ["node_pk"]
+
+    def __init__(self, node_pk: int, **kwargs):
+        """Initialize the unstash operator.
+
+        :param node_pk: Primary key of the CalcJobNode to unstash
+        """
+        super().__init__(**kwargs)
+        self.node_pk = node_pk
+
+    def execute(self, context: Context):
+        """Defer to the unstash trigger."""
+        self.defer(
+            trigger=AiiDAUnstashTrigger(node_pk=self.node_pk),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict):
+        """Handle the trigger completion."""
+        if event["status"] == "error":
+            error_msg = f"Unstash failed: {event['message']}"
+            if "traceback" in event:
+                error_msg += f"\n\nFull traceback:\n{event['traceback']}"
+            raise ValueError(error_msg)
+
+        self.log.info("Unstash completed successfully")
+
+
+class AiiDAAsyncKillOperator(BaseOperator):
+    """Operator that defers to AiiDAKillTrigger to kill a CalcJob.
+
+    This operator executes the AiiDA task_kill_job function asynchronously.
+    """
+
+    template_fields = ["node_pk"]
+
+    def __init__(self, node_pk: int, **kwargs):
+        """Initialize the kill operator.
+
+        :param node_pk: Primary key of the CalcJobNode to kill
+        """
+        super().__init__(**kwargs)
+        self.node_pk = node_pk
+
+    def execute(self, context: Context):
+        """Defer to the kill trigger."""
+        self.defer(
+            trigger=AiiDAKillTrigger(node_pk=self.node_pk),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict):
+        """Handle the trigger completion."""
+        if event["status"] == "error":
+            error_msg = f"Kill failed: {event['message']}"
+            if "traceback" in event:
+                error_msg += f"\n\nFull traceback:\n{event['traceback']}"
+            raise ValueError(error_msg)
+
+        killed = event.get("killed", False)
+        self.log.info(f"Kill completed successfully. Killed: {killed}")
+        return {"killed": killed}
diff --git a/src/airflow_provider_aiida/taskgroups/async_aiida_calcjob.py b/src/airflow_provider_aiida/taskgroups/async_aiida_calcjob.py
new file mode 100644
index 0000000..4e3d887
--- /dev/null
+++ b/src/airflow_provider_aiida/taskgroups/async_aiida_calcjob.py
@@ -0,0 +1,465 @@
+"""
+Async AiiDA CalcJob TaskGroup using AiiDA Core Task Functions
+
+This taskgroup uses the AiiDA async operators that directly wrap aiida-core's
+calcjob task functions, providing native AiiDA CalcJob execution in Airflow.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from airflow.utils.task_group import TaskGroup
+from airflow.operators.python import PythonOperator, BranchPythonOperator
+
+from aiida.engine import CalcJobProcessSpec
+
+from airflow_provider_aiida.operators.async_aiida_calcjob import (
+    AiiDAAsyncUploadOperator,
+    AiiDAAsyncSubmitOperator,
+    AiiDAAsyncUpdateOperator,
+    AiiDAAsyncMonitorOperator,
+    AiiDAAsyncRetrieveOperator,
+    AiiDAAsyncStashOperator,
+    AiiDAAsyncUnstashOperator,
+)
+
+
+class AsyncAiiDACalcJobTaskGroup(TaskGroup, ABC):
+    """
+    Abstract TaskGroup for async AiiDA CalcJob workflows using deferrable operators.
+
+    This version directly uses AiiDA's calcjob task functions through triggers,
+    providing native AiiDA CalcJob execution with Airflow's async capabilities.
+
+    The workflow follows AiiDA's CalcJob state machine:
+    - UPLOAD -> SUBMIT -> UPDATE -> STASH -> RETRIEVE -> PARSE
+
+    Or if skip_submit is True:
+    - UPLOAD -> STASH -> RETRIEVE -> PARSE
+
+    Subclasses override the define() class method to declare their inputs and must
+    implement the prepare_for_submission() and parse() methods.
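+
+    A minimal subclass and DAG usage sketch (illustrative only: the AddJob class, its
+    x/y inputs, the 'localhost' computer label and the file names are assumptions, not
+    part of this module)::
+
+        from airflow import DAG
+
+        class AddJob(AsyncAiiDACalcJobTaskGroup):
+
+            @classmethod
+            def define(cls, spec: CalcJobProcessSpec):
+                spec.input('computer')
+                spec.input('x')
+                spec.input('y')
+
+            def prepare_for_submission(self, folder):
+                from aiida.common.datastructures import CalcInfo, CodeInfo
+                # Write the input file into the sandbox folder
+                with folder.open('aiida.in', 'w') as handle:
+                    handle.write(f'{self.resolved_inputs.x} {self.resolved_inputs.y}')
+                code_info = CodeInfo()
+                code_info.cmdline_params = []
+                code_info.stdin_name = 'aiida.in'
+                code_info.stdout_name = 'aiida.out'
+                calc_info = CalcInfo()
+                calc_info.codes_info = [code_info]
+                calc_info.retrieve_list = ['aiida.out']
+                return calc_info
+
+            def parse(self, retrieved_temporary_folder, **context):
+                with open(f'{retrieved_temporary_folder}/aiida.out') as handle:
+                    return {'result': handle.read().strip()}
+
+        with DAG(dag_id='add_job_example') as dag:
+            AddJob(group_id='add', computer='localhost', x='1', y='2')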
+ """ + + def __init__( + self, + group_id: str, + node_pk: int | None = None, + retrieved_temporary_folder: str | None = None, + enable_stash: bool = False, + enable_unstash: bool = False, + enable_monitors: bool = False, + monitors_pk: int | None = None, + update_sleep_interval: int = 5, + submit_script_filename: str = '_aiidasubmit.sh', + **kwargs + ): + """Initialize the AiiDA CalcJob TaskGroup. + + :param group_id: Unique identifier for this task group + :param node_pk: Primary key of existing CalcJobNode (if reusing existing node) + :param retrieved_temporary_folder: Path for retrieved files + :param enable_stash: Whether to enable stashing (optional) + :param enable_unstash: Whether to enable unstashing (optional) + :param enable_monitors: Whether to enable monitoring + :param monitors_pk: Primary key of CalcJobMonitors node + :param update_sleep_interval: Seconds between update checks + :param submit_script_filename: Name of the submit script file + """ + super().__init__(group_id=group_id) + self.node_pk = node_pk + self.retrieved_temporary_folder = retrieved_temporary_folder or f"/tmp/aiida_retrieved_{group_id}" + self.enable_stash = enable_stash + self.enable_unstash = enable_unstash + self.enable_monitors = enable_monitors + self.monitors_pk = monitors_pk + self.update_sleep_interval = update_sleep_interval + self.submit_script_filename = submit_script_filename + + # Initialize spec and store inputs using AiiDA's CalcJobProcessSpec + self._spec = CalcJobProcessSpec() + self.define(self._spec) + + # Store inputs from kwargs based on spec + # These may be template strings that will be rendered at task execution time + self.inputs = {} + for port_name in self._spec.inputs.ports.keys(): + if port_name in kwargs: + self.inputs[port_name] = kwargs[port_name] + + # Build the task group when instantiated + self._build_tasks() + + def _create_calcjob_and_prepare(self, inputs: dict, **context) -> int: + """Create CalcJobNode, prepare for submission, and store everything. + + This combines node creation and preparation into one task so we can + store files in the repository before the node is stored. 
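+        (Once a CalcJobNode is stored, its file repository becomes immutable, so the
+        input files have to be written to the repository before calling store().)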
+ + Args: + inputs: Dictionary of inputs (templates already rendered by Airflow) + + Returns: + int: Primary key of the created CalcJobNode + """ + from aiida.orm import CalcJobNode, load_computer, to_aiida_type + from aiida.common.datastructures import CalcJobState + from aiida.common.folders import SandboxFolder + from aiida.manage.configuration import get_config_option + from types import SimpleNamespace + + # Get computer from inputs (should be in spec) + if 'computer' not in inputs: + raise ValueError("CalcJob requires 'computer' input in spec") + + # Convert inputs to AiiDA nodes using to_aiida_type + # At this point, templates are already rendered by Airflow + resolved_inputs = {} + for key, value in inputs.items(): + # If it's already an AiiDA node, use it + if hasattr(value, 'value'): + resolved_inputs[key] = value + else: + # Convert to AiiDA type (value is now the rendered string) + resolved_inputs[key] = to_aiida_type(value) + + computer_label = resolved_inputs['computer'].value + computer = load_computer(computer_label) + + # Create CalcJobNode (but don't store yet) + node = CalcJobNode(computer=computer) + node.set_attribute('process_label', self.__class__.__name__) + + # Set CalcJob metadata options (required for proper job submission) + node.set_option('submit_script_filename', self.submit_script_filename) + node.set_option('scheduler_stdout', '_scheduler-stdout.txt') + node.set_option('scheduler_stderr', '_scheduler-stderr.txt') + node.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) + + # Store all inputs as attributes + for key, aiida_node in resolved_inputs.items(): + if key != 'computer': # Computer is already set + # Extract value from AiiDA Data nodes + attr_value = aiida_node.value if hasattr(aiida_node, 'value') else aiida_node + node.base.attributes.set(key, attr_value) + + # Set initial state + node.set_state(CalcJobState.UPLOADING) + + # Now prepare for submission BEFORE storing the node + filepath_sandbox = get_config_option('storage.sandbox') or None + + # Create a namespace with resolved inputs (similar to AiiDA's self.inputs) + resolved_inputs_ns = SimpleNamespace() + for key, aiida_node in resolved_inputs.items(): + if key == 'computer': + value = computer_label + else: + value = aiida_node.value if hasattr(aiida_node, 'value') else aiida_node + setattr(resolved_inputs_ns, key, value) + + # Store resolved inputs on self so prepare_for_submission can access them + self.resolved_inputs = resolved_inputs_ns + + with SandboxFolder(filepath_sandbox) as folder: + # Call the user-defined prepare_for_submission + calc_info = self.prepare_for_submission(folder) + + # Set uuid on calc_info (required by execmanager) - this is what presubmit does + calc_info.uuid = str(node.uuid) + + # Update retrieve lists on the node (similar to presubmit) + # This adds scheduler stdout/stderr to the retrieve list + retrieve_list = calc_info.retrieve_list or [] + node.set_retrieve_list(retrieve_list) + + retrieve_temporary_list = calc_info.retrieve_temporary_list or [] + node.set_retrieve_temporary_list(retrieve_temporary_list) + + # Store the folder contents in the node's repository (node not stored yet) + node.put_object_from_tree(folder.abspath, '') + + # Store CalcInfo details as node attributes + node.base.attributes.set('_calc_info_uuid', calc_info.uuid if hasattr(calc_info, 'uuid') else None) + node.base.attributes.set('_calc_info_skip_submit', calc_info.skip_submit or False) + node.base.attributes.set('_calc_info_codes_info', [ + { + 'code_uuid': 
str(ci.code_uuid) if hasattr(ci, 'code_uuid') and ci.code_uuid else None, + 'cmdline_params': ci.cmdline_params, + 'stdin_name': ci.stdin_name, + 'stdout_name': ci.stdout_name, + 'stderr_name': ci.stderr_name, + 'join_files': ci.join_files, + } for ci in (calc_info.codes_info or []) + ]) + + # Store retrieve lists + node.set_retrieve_list(calc_info.retrieve_list or []) + node.set_retrieve_temporary_list(calc_info.retrieve_temporary_list or []) + + # NOW store the node (after repository is populated) + node.store() + + return node.pk + + + @classmethod + def define(cls, spec: CalcJobProcessSpec): + """Define the input/output specification using AiiDA's CalcJobProcessSpec. + + Subclasses should override this to specify their inputs. + + :param spec: CalcJobProcessSpec to define inputs/outputs on + """ + pass + + @abstractmethod + def prepare_for_submission(self, folder) -> 'CalcInfo': + """Prepare the calculation for submission. + + This method should be implemented by subclasses to: + 1. Write input files to the folder + 2. Create and return a CalcInfo object with job submission details + + Access inputs via self.resolved_inputs.x, self.resolved_inputs.y, etc. + + Args: + folder: A SandboxFolder where input files should be written + + Returns: + CalcInfo: Object containing submission details (codes to execute, files to copy, etc.) + """ + pass + + def _build_tasks(self): + """Build all tasks within this task group following AiiDA's CalcJob workflow.""" + + # Task to create CalcJobNode and prepare for submission (only if node_pk not provided) + if not self.node_pk: + prepare_task = PythonOperator( + task_id='prepare_calcjob', + python_callable=self._create_calcjob_and_prepare, + op_kwargs={'inputs': self.inputs}, + task_group=self, + ) + # Get the node_pk to use downstream + node_pk_ref = prepare_task.output + else: + node_pk_ref = self.node_pk + + # Optional: Unstash task + if self.enable_unstash: + unstash_op = AiiDAAsyncUnstashOperator( + task_id="unstash", + node_pk=node_pk_ref, + task_group=self, + ) + + # Upload task + upload_op = AiiDAAsyncUploadOperator( + task_id="upload", + node_pk=node_pk_ref, + task_group=self, + ) + + # Branch based on skip_submit flag + branch_task = BranchPythonOperator( + task_id="check_skip_submit", + python_callable=self._check_skip_submit, + task_group=self, + ) + + # Submit task (only if not skipping) + submit_op = AiiDAAsyncSubmitOperator( + task_id="submit", + node_pk=node_pk_ref, + task_group=self, + ) + + # Update task (monitor job status) + update_op = AiiDAAsyncUpdateOperator( + task_id="update", + node_pk=node_pk_ref, + sleep_interval=self.update_sleep_interval, + task_group=self, + ) + + # Optional: Monitor task + if self.enable_monitors: + monitor_op = AiiDAAsyncMonitorOperator( + task_id="monitor", + node_pk=node_pk_ref, + monitors_pk=self.monitors_pk, + task_group=self, + ) + + # Optional: Stash task + if self.enable_stash: + stash_op = AiiDAAsyncStashOperator( + task_id="stash", + node_pk=node_pk_ref, + task_group=self, + ) + + # Retrieve task + retrieve_op = AiiDAAsyncRetrieveOperator( + task_id="retrieve", + node_pk=node_pk_ref, + retrieved_temporary_folder=self.retrieved_temporary_folder, + task_group=self, + ) + + # Parse task + parse_task = PythonOperator( + task_id='parse', + python_callable=self.parse, + op_kwargs={'retrieved_temporary_folder': self.retrieved_temporary_folder}, + task_group=self, + ) + + # Set up dependencies + if not self.node_pk: + if self.enable_unstash: + prepare_task >> unstash_op >> upload_op + else: + 
prepare_task >> upload_op + else: + if self.enable_unstash: + unstash_op >> upload_op + + upload_op >> branch_task + + # Full workflow: upload -> submit -> update -> [monitor] -> [stash] -> retrieve -> parse + branch_task >> submit_op >> update_op + + if self.enable_monitors: + update_op >> monitor_op + next_task = monitor_op + else: + next_task = update_op + + if self.enable_stash: + next_task >> stash_op >> retrieve_op + else: + next_task >> retrieve_op + + # Skip submit workflow: upload -> [stash] -> retrieve -> parse + if self.enable_stash: + branch_task >> stash_op + else: + branch_task >> retrieve_op + + retrieve_op >> parse_task + + def _create_calcjob_wrapper(self, **context): + """Wrapper to create the CalcJobNode generically from spec inputs.""" + if self.node_pk: + return {"node_pk": self.node_pk} + + # Generic CalcJobNode creation using the spec inputs + from aiida.orm import CalcJobNode, load_computer, to_aiida_type + from aiida.common.datastructures import CalcJobState + + # Get computer from inputs (should be in spec) + if 'computer' not in self.inputs: + raise ValueError("CalcJob requires 'computer' input in spec") + + # Convert inputs to AiiDA nodes using to_aiida_type + resolved_inputs = {} + for key, value in self.inputs.items(): + # If it's already an AiiDA node, use it + if hasattr(value, 'value'): + resolved_inputs[key] = value + else: + # Convert to AiiDA type + resolved_inputs[key] = to_aiida_type(value) + + computer_label = resolved_inputs['computer'].value + computer = load_computer(computer_label) + + # Create CalcJobNode + node = CalcJobNode(computer=computer) + node.set_attribute('process_label', self.__class__.__name__) + + # Store all inputs as attributes + for key, aiida_node in resolved_inputs.items(): + if key != 'computer': # Computer is already set + # Extract value from AiiDA Data nodes + attr_value = aiida_node.value if hasattr(aiida_node, 'value') else aiida_node + node.base.attributes.set(key, attr_value) + + # Set initial state + node.set_state(CalcJobState.UPLOADING) + + # Store the node + node.store() + + return {"node_pk": node.pk} + + def _check_skip_submit(self, **context): + """Check if we should skip the submit step based on upload results.""" + ti = context['task_instance'] + upload_result = ti.xcom_pull(task_ids=f'{self.group_id}.upload') + + if upload_result and upload_result.get('skip_submit'): + # Skip to stash (if enabled) or retrieve + if self.enable_stash: + return f'{self.group_id}.stash' + return f'{self.group_id}.retrieve' + else: + # Continue with submit + return f'{self.group_id}.submit' + + @abstractmethod + def parse(self, retrieved_temporary_folder: str, **context) -> dict[str, Any]: + """Abstract method to parse job outputs. + + This method should: + 1. Read the retrieved files from the temporary folder + 2. Parse the results + 3. Store results in the AiiDA database if needed + 4. Return parsed results + + Args: + retrieved_temporary_folder: Path to folder containing retrieved files + + Returns: + dict: Dictionary containing parsed results + """ + pass + + +class SimpleAsyncAiiDACalcJobTaskGroup(AsyncAiiDACalcJobTaskGroup): + """ + Simplified version of AsyncAiiDACalcJobTaskGroup without optional features. 
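+    The class remains abstract: concrete subclasses still implement prepare_for_submission()
+    and parse(), and usually override define() to declare their inputs.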
+ + This provides a minimal workflow: UPLOAD -> SUBMIT -> UPDATE -> RETRIEVE -> PARSE + """ + + def __init__( + self, + group_id: str, + node_pk: int | None = None, + retrieved_temporary_folder: str | None = None, + update_sleep_interval: int = 5, + submit_script_filename: str = '_aiidasubmit.sh', + **kwargs + ): + """Initialize the simple AiiDA CalcJob TaskGroup. + + :param group_id: Unique identifier for this task group + :param node_pk: Primary key of existing CalcJobNode (if reusing existing node) + :param retrieved_temporary_folder: Path for retrieved files + :param update_sleep_interval: Seconds between update checks + :param submit_script_filename: Name of the submit script file + """ + super().__init__( + group_id=group_id, + node_pk=node_pk, + retrieved_temporary_folder=retrieved_temporary_folder, + enable_stash=False, + enable_unstash=False, + enable_monitors=False, + update_sleep_interval=update_sleep_interval, + submit_script_filename=submit_script_filename, + **kwargs + ) diff --git a/src/airflow_provider_aiida/triggers/async_aiida_calcjob.py b/src/airflow_provider_aiida/triggers/async_aiida_calcjob.py new file mode 100644 index 0000000..a28060a --- /dev/null +++ b/src/airflow_provider_aiida/triggers/async_aiida_calcjob.py @@ -0,0 +1,423 @@ +"""Airflow triggers that wrap AiiDA CalcJob transport tasks. + +These triggers directly execute the task functions from aiida-core's calcjob tasks module, +allowing CalcJob operations to be performed asynchronously in the Airflow triggerer. +""" + +import asyncio +import logging +from typing import Any, AsyncIterator + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +from airflow_provider_aiida.aiida_core.engine.calcjob.tasks import ( + task_upload_job, + task_submit_job, + task_update_job, + task_monitor_job, + task_retrieve_job, + task_stash_job, + task_unstash_job, + task_kill_job, +) +from aiida.engine.utils import InterruptableFuture +from aiida.orm import load_node +from airflow_provider_aiida.aiida_core.transport import get_transport_queue + +logger = logging.getLogger(__name__) + + +class AiiDAUploadTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_upload_job function.""" + + def __init__(self, node_pk: int): + """Initialize the upload trigger. 
+ + :param node_pk: Primary key of the CalcJobNode to upload + """ + super().__init__() + self.node_pk = node_pk + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDAUploadTrigger", + {"node_pk": self.node_pk}, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the upload task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + from aiida.common.datastructures import CalcInfo, CodeInfo + + # Load the CalcJobNode + node = load_node(self.node_pk) + + # Reconstruct calc_info from node attributes + calc_info = CalcInfo() + calc_info.uuid = node.base.attributes.get('_calc_info_uuid', str(node.uuid)) + calc_info.skip_submit = node.base.attributes.get('_calc_info_skip_submit', False) + + # Reconstruct codes_info if present + codes_info_data = node.base.attributes.get('_calc_info_codes_info', []) + calc_info.codes_info = [] + for ci_data in codes_info_data: + code_info = CodeInfo() + code_info.code_uuid = ci_data.get('code_uuid') + code_info.cmdline_params = ci_data.get('cmdline_params', []) + code_info.stdin_name = ci_data.get('stdin_name') + code_info.stdout_name = ci_data.get('stdout_name') + code_info.stderr_name = ci_data.get('stderr_name') + code_info.join_files = ci_data.get('join_files', False) + calc_info.codes_info.append(code_info) + + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + skip_submit = await task_upload_job(node, transport_queue, cancellable, calc_info) + + yield TriggerEvent({ + "status": "success", + "skip_submit": skip_submit, + }) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Upload task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDASubmitTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_submit_job function.""" + + def __init__(self, node_pk: int): + """Initialize the submit trigger. + + :param node_pk: Primary key of the CalcJobNode to submit + """ + super().__init__() + self.node_pk = node_pk + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDASubmitTrigger", + {"node_pk": self.node_pk}, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the submit task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + job_id = await task_submit_job(node, transport_queue, cancellable) + + yield TriggerEvent({ + "status": "success", + "job_id": job_id, + }) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Submit task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDAUpdateTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_update_job function. + + This trigger polls the job status until it's complete. + """ + + def __init__(self, node_pk: int, sleep_interval: int = 5): + """Initialize the update trigger. 
+ + :param node_pk: Primary key of the CalcJobNode to update + :param sleep_interval: Seconds to sleep between update checks + """ + super().__init__() + self.node_pk = node_pk + self.sleep_interval = sleep_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDAUpdateTrigger", + { + "node_pk": self.node_pk, + "sleep_interval": self.sleep_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the update task repeatedly until job is done.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + from aiida.engine.processes.calcjobs.manager import JobManager + job_manager = JobManager(transport_queue) + cancellable = InterruptableFuture() + + job_done = False + while not job_done: + job_done = await task_update_job(node, job_manager, cancellable) + + if not job_done: + await asyncio.sleep(self.sleep_interval) + + yield TriggerEvent({ + "status": "success", + "job_done": True, + }) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Update task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDAMonitorTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_monitor_job function.""" + + def __init__(self, node_pk: int, monitors_pk: int | None = None): + """Initialize the monitor trigger. + + :param node_pk: Primary key of the CalcJobNode to monitor + :param monitors_pk: Primary key of the CalcJobMonitors node (if applicable) + """ + super().__init__() + self.node_pk = node_pk + self.monitors_pk = monitors_pk + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDAMonitorTrigger", + { + "node_pk": self.node_pk, + "monitors_pk": self.monitors_pk, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the monitor task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + # Load monitors if provided + from aiida.engine.processes.calcjobs.monitors import CalcJobMonitors + monitors = None + if self.monitors_pk: + monitors_node = load_node(self.monitors_pk) + monitors = CalcJobMonitors(monitors_node) + + monitor_result = await task_monitor_job( + node, transport_queue, cancellable, monitors + ) + + result_dict = {"status": "success"} + if monitor_result: + result_dict["action"] = monitor_result.action + result_dict["message"] = monitor_result.message + result_dict["retrieve"] = monitor_result.retrieve + result_dict["parse"] = monitor_result.parse + + yield TriggerEvent(result_dict) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Monitor task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDARetrieveTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_retrieve_job function.""" + + def __init__(self, node_pk: int, retrieved_temporary_folder: str): + """Initialize the retrieve trigger. 
+ + :param node_pk: Primary key of the CalcJobNode to retrieve + :param retrieved_temporary_folder: Path to temporary folder for retrieved files + """ + super().__init__() + self.node_pk = node_pk + self.retrieved_temporary_folder = retrieved_temporary_folder + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDARetrieveTrigger", + { + "node_pk": self.node_pk, + "retrieved_temporary_folder": self.retrieved_temporary_folder, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the retrieve task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + # Create the retrieved_temporary_folder if it doesn't exist + from pathlib import Path + Path(self.retrieved_temporary_folder).mkdir(parents=True, exist_ok=True) + + retrieved = await task_retrieve_job( + node, transport_queue, self.retrieved_temporary_folder, cancellable + ) + + yield TriggerEvent({ + "status": "success", + "retrieved": retrieved is not None, + }) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Retrieve task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDAStashTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_stash_job function.""" + + def __init__(self, node_pk: int): + """Initialize the stash trigger. + + :param node_pk: Primary key of the CalcJobNode to stash + """ + super().__init__() + self.node_pk = node_pk + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDAStashTrigger", + {"node_pk": self.node_pk}, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the stash task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + await task_stash_job(node, transport_queue, cancellable) + + yield TriggerEvent({"status": "success"}) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Stash task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDAUnstashTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_unstash_job function.""" + + def __init__(self, node_pk: int): + """Initialize the unstash trigger. 
+ + :param node_pk: Primary key of the CalcJobNode to unstash + """ + super().__init__() + self.node_pk = node_pk + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDAUnstashTrigger", + {"node_pk": self.node_pk}, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the unstash task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + await task_unstash_job(node, transport_queue, cancellable) + + yield TriggerEvent({"status": "success"}) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Unstash task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb}) + + +class AiiDAKillTrigger(BaseTrigger): + """Trigger that executes the AiiDA task_kill_job function.""" + + def __init__(self, node_pk: int): + """Initialize the kill trigger. + + :param node_pk: Primary key of the CalcJobNode to kill + """ + super().__init__() + self.node_pk = node_pk + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for persistence.""" + return ( + "airflow_provider_aiida.triggers.async_aiida_calcjob.AiiDAKillTrigger", + {"node_pk": self.node_pk}, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Execute the kill task.""" + try: + # Load AiiDA profile (triggers run in separate process) + from aiida import load_profile + load_profile() + + node = load_node(self.node_pk) + transport_queue = get_transport_queue() + cancellable = InterruptableFuture() + + result = await task_kill_job(node, transport_queue, cancellable) + + yield TriggerEvent({ + "status": "success", + "killed": result, + }) + except Exception as e: + import traceback + tb = traceback.format_exc() + logger.exception(f"Kill task failed for node {self.node_pk}") + yield TriggerEvent({"status": "error", "message": str(e), "traceback": tb})
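+
+
+# A minimal way to exercise one of these triggers outside of Airflow for debugging
+# (illustrative only: the node pk 1234 is a placeholder, and a configured AiiDA
+# profile plus a reachable computer are assumed):
+#
+#     import asyncio
+#
+#     async def _debug_upload(pk: int) -> None:
+#         # BaseTrigger.run() is an async generator yielding TriggerEvent objects
+#         async for event in AiiDAUploadTrigger(node_pk=pk).run():
+#             print(event.payload)
+#
+#     asyncio.run(_debug_upload(1234))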