diff --git a/run_arithmetic_dag.py b/run_arithmetic_dag.py new file mode 100644 index 0000000..da11ed1 --- /dev/null +++ b/run_arithmetic_dag.py @@ -0,0 +1,43 @@ +from pathlib import Path +import os +from airflow.api.client.local_client import Client + +# Set AIRFLOW__CORE__DAGS_FOLDER to include example_dags +dag_folder = str( + Path(__file__).parent / "src" / "airflow_provider_aiida" / "example_dags" +) +os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = dag_folder + +# Import AFTER setting the environment variable +from airflow.models import DagBag + +# Create directories +Path("/tmp/airflow/local_workdir").mkdir(parents=True, exist_ok=True) +Path("/tmp/airflow/remote_workdir").mkdir(parents=True, exist_ok=True) + +# Configuration +conf = { + "machine": "localhost", + "local_workdir": "/tmp/airflow/local_workdir", + "remote_workdir": "/tmp/airflow/remote_workdir", + "add_x": 10, + "add_y": 5, + "multiply_x": 7, + "multiply_y": 3, +} + +# Run DAG using test mode (bypasses serialization requirement) +dagbag = DagBag(dag_folder=dag_folder, include_examples=False) +dag = dagbag.get_dag("arithmetic_add_multiply") + +# Use test mode with execution_date to avoid serialization issues + +# dag.test( +# run_conf=conf, +# # execution_date=datetime.now(), +# use_executor=False, # Run tasks sequentially in the same process +# ) + +# Trigger DAG using API client (requires scheduler to be running) +client: Client = Client() +client.trigger_dag(dag_id="arithmetic_add_multiply", conf=conf) diff --git a/src/airflow_provider_aiida/example_dags/arithmetic_add.py b/src/airflow_provider_aiida/example_dags/arithmetic_add.py index a7f045b..811255e 100644 --- a/src/airflow_provider_aiida/example_dags/arithmetic_add.py +++ b/src/airflow_provider_aiida/example_dags/arithmetic_add.py @@ -11,6 +11,10 @@ class AddJobTaskGroup(CalcJobTaskGroup): """Addition job task group - directly IS a TaskGroup""" + # Define AiiDA input/output port names (like in aiida-core CalcJob.define()) + # AIIDA_INPUT_PORTS = ['x', 'y'] + # AIIDA_OUTPUT_PORTS = ['sum'] + def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workdir: str, x: int, y: int, sleep: int, **kwargs): self.x = x @@ -20,11 +24,20 @@ def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workd def prepare(self, **context) -> Dict[str, Any]: """Prepare addition job inputs""" + # Resolve template variables from params + from airflow.models import TaskInstance + ti: TaskInstance = context['task_instance'] + params = context['params'] + + x = params['add_x'] + y = params['add_y'] + sleep = 3 # or get from params if needed + to_upload_files = {} submission_script = f""" -sleep {self.sleep} -echo "$(({self.x}+{self.y}))" > result.out +sleep {sleep} +echo "$(({x}+{y}))" > result.out """ to_receive_files = {"result.out": "addition_result.txt"} @@ -34,6 +47,9 @@ def prepare(self, **context) -> Dict[str, Any]: context['task_instance'].xcom_push(key='submission_script', value=submission_script) context['task_instance'].xcom_push(key='to_receive_files', value=to_receive_files) + # Push AiiDA inputs for provenance (matches AIIDA_INPUT_PORTS) + context['task_instance'].xcom_push(key='aiida_inputs', value={'x': x, 'y': y}) + return { "to_upload_files": to_upload_files, "submission_script": submission_script, @@ -59,7 +75,7 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]: continue result_content = file_path.read_text().strip() - print(f"Addition result ({self.x} + {self.y}): {result_content}") + print(f"Addition result: 
{result_content}") results[file_key] = int(result_content) except Exception as e: @@ -69,12 +85,21 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]: # Store both exit status and results in XCom final_result = (exit_status, results) context['task_instance'].xcom_push(key='final_result', value=final_result) + + # Push AiiDA outputs for provenance (matches AIIDA_OUTPUT_PORTS) + if 'result.out' in results: + context['task_instance'].xcom_push(key='aiida_outputs', value={'sum': results['result.out']}) + return final_result class MultiplyJobTaskGroup(CalcJobTaskGroup): """Multiplication job task group - directly IS a TaskGroup""" + # Define AiiDA input/output port names (like in aiida-core CalcJob.define()) + # AIIDA_INPUT_PORTS = ['x', 'y'] + # AIIDA_OUTPUT_PORTS = ['result'] + def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workdir: str, x: int, y: int, sleep: int, **kwargs): self.x = x @@ -84,12 +109,21 @@ def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workd def prepare(self, **context) -> Dict[str, Any]: """Prepare multiplication job inputs""" + # Resolve template variables from params + from airflow.models import TaskInstance + ti: TaskInstance = context['task_instance'] + params = context['params'] + + x = params['multiply_x'] + y = params['multiply_y'] + sleep = 2 # or get from params if needed + to_upload_files = {} submission_script = f""" -sleep {self.sleep} -echo "$(({self.x}*{self.y}))" > multiply_result.out -echo "Operation: {self.x} * {self.y}" > operation.log +sleep {sleep} +echo "$(({x}*{y}))" > multiply_result.out +echo "Operation: {x} * {y}" > operation.log """ to_receive_files = { @@ -102,6 +136,9 @@ def prepare(self, **context) -> Dict[str, Any]: context['task_instance'].xcom_push(key='submission_script', value=submission_script) context['task_instance'].xcom_push(key='to_receive_files', value=to_receive_files) + # Push AiiDA inputs for provenance (matches AIIDA_INPUT_PORTS) + context['task_instance'].xcom_push(key='aiida_inputs', value={'x': x, 'y': y}) + return { "to_upload_files": to_upload_files, "submission_script": submission_script, @@ -146,6 +183,11 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]: # Store both exit status and results in XCom final_result = (exit_status, results) context['task_instance'].xcom_push(key='final_result', value=final_result) + + # Push AiiDA outputs for provenance (matches AIIDA_OUTPUT_PORTS) + if 'result' in results: + context['task_instance'].xcom_push(key='aiida_outputs', value={'result': results['result']}) + return final_result @@ -238,4 +280,4 @@ def combine_results(): # Direct usage - add_job and multiply_job ARE TaskGroups! 
combine_task = combine_results() - [add_job, multiply_job] >> combine_task \ No newline at end of file + [add_job, multiply_job] >> combine_task diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index 9aaae8a..93ec6b8 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -1,210 +1,475 @@ -import sqlite3 -import os import logging -from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional + +from airflow.models import DagRun, TaskInstance from airflow.plugins_manager import AirflowPlugin +from airflow.sdk.definitions.param import Param +from airflow.models import Param as ModelsParam from airflow.listeners import hookimpl -from airflow.models import DagRun, XCom -from airflow.utils.state import DagRunState -from sqlalchemy import inspect as sqlalchemy_inspect +from aiida import load_profile, orm +from aiida.common.links import LinkType +import json + +load_profile() logger = logging.getLogger(__name__) -# Database path -DB_PATH = os.path.join(os.path.dirname(__file__), 'dagrun_tracking.db') -def _init_database(): - """Initialize the SQLite database and create the dagrun table if it doesn't exist.""" - try: - with sqlite3.connect(DB_PATH) as conn: - cursor = conn.cursor() - cursor.execute(''' - CREATE TABLE IF NOT EXISTS dagrun_events ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - dag_id TEXT NOT NULL, - run_id TEXT NOT NULL, - run_type TEXT, - state TEXT, - execution_date TEXT, - start_date TEXT, - end_date TEXT, - external_trigger BOOLEAN, - conf TEXT, - dag_output TEXT, - event_type TEXT NOT NULL, - event_timestamp TEXT NOT NULL, - created_at TEXT DEFAULT CURRENT_TIMESTAMP +def _param_to_python(param) -> Any: + """ + Convert an Airflow Param object to a Python native value. + + Args: + param: Airflow Param object or any other value + + Returns: + Python native value (int, float, bool, str, dict, list, etc.) + """ + # Check if it's a Param object + if not isinstance(param, (Param, ModelsParam)): + return param + + # Get the actual value + actual_value = param.value + + # Get schema type if available + schema = getattr(param, "schema", {}) + param_type = schema.get("type", None) + + # Convert based on schema type + if param_type == "integer": + try: + return int(actual_value) + except (ValueError, TypeError): + logger.warning(f"Could not convert Param value '{actual_value}' to int") + return actual_value + elif param_type == "number": + try: + return float(actual_value) + except (ValueError, TypeError): + logger.warning(f"Could not convert Param value '{actual_value}' to float") + return actual_value + elif param_type == "boolean": + if isinstance(actual_value, bool): + return actual_value + if isinstance(actual_value, str): + return actual_value.lower() in ("true", "1", "yes", "on") + return bool(actual_value) + elif param_type == "string": + return str(actual_value) + elif param_type == "object": + return actual_value if isinstance(actual_value, dict) else {} + elif param_type == "array": + return actual_value if isinstance(actual_value, (list, tuple)) else [] + else: + return actual_value + + +def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: + """ + Convert a Python value to the appropriate AiiDA Data node. + + Returns None if the value type is not supported or should be skipped. 
+    """
+    # First check if it's an Airflow Param and convert it
+    if isinstance(value, (Param, ModelsParam)):
+        value = _param_to_python(value)
+
+    # Handle basic types (check bool BEFORE int, since bool is a subclass of int)
+    if isinstance(value, bool):
+        return orm.Bool(value)
+    elif isinstance(value, int):
+        return orm.Int(value)
+    elif isinstance(value, float):
+        return orm.Float(value)
+    elif isinstance(value, str):
+        return orm.Str(value)
+    # Handle collections - store as Dict or List nodes
+    elif isinstance(value, dict):
+        return orm.Dict(dict=value)
+    elif isinstance(value, (list, tuple)):
+        return orm.List(list=list(value))
+    # Handle Path objects
+    elif isinstance(value, Path):
+        return orm.Str(str(value))
+    # For complex objects, try JSON serialization
+    else:
+        try:
+            json_str = json.dumps(value)
+            return orm.Str(json_str)
+        except (TypeError, ValueError):
+            logger.warning(
+                f"Could not convert value of type {type(value)} to AiiDA node"
+            )
+            return None
+
+
+def _store_dag_inputs_in_aiida(
+    node: orm.Node, params: Dict[str, Any], prefix: str = ""
+) -> None:
+    """
+    Store parameters as AiiDA data nodes and link them as inputs.
+
+    Args:
+        node: The AiiDA node to link inputs to
+        params: Dictionary of parameters to store
+        prefix: Optional prefix for link labels (e.g., 'dag_params', 'conf')
+    """
+    for key, value in params.items():
+        # Create link label with optional prefix
+        link_label = f"{prefix}_{key}" if prefix else key
+
+        # Skip None values
+        if value is None:
+            continue
+
+        # Convert to AiiDA data node
+        aiida_data = _convert_to_aiida_data(value)
+
+        # Pick the input link type from the node class; skip anything else so
+        # that link_type can never be referenced while unbound below
+        if isinstance(node, orm.WorkflowNode):
+            link_type = LinkType.INPUT_WORK
+        elif isinstance(node, orm.CalculationNode):
+            link_type = LinkType.INPUT_CALC
+        else:
+            logger.warning(f"Cannot link inputs to node of type {type(node)}")
+            continue
+
+        if aiida_data is not None:
+            try:
+                # Store the data node first
+                aiida_data.store()
+                # Then add the link
+                node.base.links.add_incoming(
+                    aiida_data, link_type=link_type, link_label=link_label
                 )
-            ''')
-            conn.commit()
-            logger.info(f"DagRun tracking database initialized at {DB_PATH}")
-    except Exception as e:
-        logger.error(f"Failed to initialize database: {e}")
+            except ValueError as e:
+                # Link already exists or other constraint violation
+                logger.debug(f"Could not link {link_label}: {e}")
+
+
+def should_create_calcjob_node_for_taskgroup(task_instance: TaskInstance) -> bool:
+    """
+    Determine if a task instance is part of a CalcJobTaskGroup.
+
+    This checks if the task is the "parse" task of a CalcJobTaskGroup,
+    which signals completion of the entire group.
+ + Args: + task_instance: Airflow task instance + + Returns: + bool: True if this is a parse task from a CalcJobTaskGroup + """ + # Check if task_id indicates it's a parse task in a task group + if ".parse" in task_instance.task_id: + # Verify parent group exists and has the expected structure + group_id = task_instance.task_id.rsplit(".parse", 1)[0] + + # Get task instances from the dag_run + # Note: In the success hook, we have access to the full dag_run + from airflow import settings + + session = settings.Session() + dag_run = ( + session.query(DagRun) + .filter( + DagRun.dag_id == task_instance.dag_id, + DagRun.run_id == task_instance.run_id, + ) + .first() + ) -def _get_dag_output(dagrun: DagRun) -> str: - """Retrieve DAG output from XCom if available.""" - try: - # Look for XCom with key 'dag_output' from any task in this DAG run - from airflow.models import XCom - from airflow.utils.session import provide_session - - @provide_session - def _query_xcom(session=None): - xcom_value = session.query(XCom).filter( - XCom.dag_id == dagrun.dag_id, - XCom.run_id == dagrun.run_id, - XCom.key == 'dag_output' - ).first() - return xcom_value.value if xcom_value else None - - output = _query_xcom() - logger.info(f"[DEBUG] Retrieved DAG output: {output}") - return str(output) if output else '{}' + if dag_run: + task_instances = dag_run.get_task_instances() + # Look for the prepare task in the same group + for ti in task_instances: + if ti.task_id == f"{group_id}.prepare": + return True - except Exception as e: - logger.warning(f"Failed to retrieve DAG output: {e}") - return '{}' + return False -def _get_dag_output_safe(dag_id: str, run_id: str) -> str: - """Safely retrieve DAG output from XCom using dag_id and run_id strings.""" - try: - # Look for XCom with key 'dag_output' from any task in this DAG run - from airflow.models import XCom - from airflow.utils.session import provide_session - - @provide_session - def _query_xcom(session=None): - xcom_value = session.query(XCom).filter( - XCom.dag_id == dag_id, - XCom.run_id == run_id, - XCom.key == 'dag_output' - ).first() - return xcom_value.value if xcom_value else None - - output = _query_xcom() - logger.info(f"[DEBUG] Retrieved DAG output: {output}") - return str(output) if output else '{}' - except Exception as e: - logger.warning(f"Failed to retrieve DAG output: {e}") - return '{}' +def _get_taskgroup_id_from_parse_task(task_instance: TaskInstance) -> str: + """Extract the task group ID from a parse task's task_id""" + return task_instance.task_id.rsplit(".parse", 1)[0] -def _store_dagrun_event(dagrun: DagRun, event_type: str): - """Store dagrun event information to SQLite database.""" + +def _store_taskgroup_inputs( + node: orm.CalcJobNode, task_instance: TaskInstance, dag_run: DagRun +) -> None: + """ + Store all inputs for a CalcJobTaskGroup. + + Inputs should be explicitly stored by the prepare task in XCom with key 'aiida_inputs'. + This allows each TaskGroup to define its own input structure. 
+ + Args: + node: The CalcJobNode to link inputs to + task_instance: The parse task instance (end of the group) + dag_run: The DAG run containing the task + """ + group_id = _get_taskgroup_id_from_parse_task(task_instance) + prepare_task_id = f"{group_id}.prepare" + + # Try to get inputs explicitly defined by the prepare task try: - # IMPORTANT: Extract all needed attributes FIRST to avoid DetachedInstanceError - # Use SQLAlchemy inspect to check if attributes are loaded - inspector = sqlalchemy_inspect(dagrun) + aiida_inputs = task_instance.xcom_pull( + task_ids=prepare_task_id, key="aiida_inputs" + ) + if aiida_inputs and isinstance(aiida_inputs, dict): + _store_dag_inputs_in_aiida(node, aiida_inputs, prefix="") + return + except Exception as e: + logger.debug(f"Could not retrieve aiida_inputs from prepare task: {e}") - # Check which attributes are loaded to avoid triggering lazy loads - unloaded = inspector.unloaded + # If no explicit inputs provided, log a warning + logger.warning( + f"No 'aiida_inputs' found in XCom for {prepare_task_id}. " + f"CalcJobTaskGroup should push a dict with key 'aiida_inputs' containing input data." + ) - # Access ALL attributes while still in session context - dag_id = dagrun.dag_id - run_id = dagrun.run_id - # Safely extract run_type - only access if already loaded - if 'run_type' not in unloaded: - run_type = dagrun.run_type - run_type_str = run_type.value if hasattr(run_type, 'value') else str(run_type) if run_type else None - else: - run_type_str = None +def _store_taskgroup_outputs( + node: orm.CalcJobNode, task_instance: TaskInstance +) -> None: + """ + Store all outputs from a CalcJobTaskGroup. + + Outputs should be explicitly stored by the parse task in XCom with key 'aiida_outputs'. + This allows each TaskGroup to define its own output structure. + + Args: + node: The CalcJobNode to link outputs to + task_instance: The parse task instance + """ + try: + # Try to get outputs explicitly defined by the parse task + aiida_outputs = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="aiida_outputs" + ) + + if aiida_outputs and isinstance(aiida_outputs, dict): + for key, value in aiida_outputs.items(): + aiida_data = _convert_to_aiida_data(value) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label=key + ) + return - # Safely extract state - only access if already loaded to avoid session refresh - if '_state' not in unloaded: - state = dagrun._state - state_str = state.value if hasattr(state, 'value') else str(state) if state else None + except Exception as e: + logger.debug(f"Could not retrieve aiida_outputs from parse task: {e}") + + # If no explicit outputs provided, log a warning + logger.warning( + f"No 'aiida_outputs' found in XCom for {task_instance.task_id}. " + f"CalcJobTaskGroup parse method should push a dict with key 'aiida_outputs' containing output data." + ) + + +def _create_calcjob_node_from_taskgroup( + task_instance: TaskInstance, + parent_workchain_node: orm.WorkChainNode, + dag_run: DagRun, +) -> orm.CalcJobNode: + """ + Create an AiiDA CalcJobNode from a CalcJobTaskGroup (represented by its parse task). 
+ + Args: + task_instance: The parse task instance (end of the TaskGroup) + parent_workchain_node: The parent WorkChainNode for the DAG + dag_run: The DAG run + + Returns: + The created and stored CalcJobNode + """ + group_id = _get_taskgroup_id_from_parse_task(task_instance) + + cj_node: orm.CalcJobNode = orm.CalcJobNode() + cj_node.label = group_id + cj_node.description = f"CalcJob from Airflow TaskGroup {group_id}" + + # Store Airflow metadata in extras + cj_node.base.extras.set("airflow_dag_id", task_instance.dag_id) + cj_node.base.extras.set("airflow_run_id", task_instance.run_id) + cj_node.base.extras.set("airflow_task_group_id", group_id) + + # Set process type to the group ID + cj_node.set_process_type(group_id) + cj_node.set_process_state("finished") + cj_node.set_process_label("AirflowCalcJob") + + # Determine exit status from parse task result + exit_status = 0 + try: + final_result = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="final_result" + ) + if isinstance(final_result, tuple) and len(final_result) == 2: + exit_status = final_result[0] + except Exception: + pass + + cj_node.set_exit_status(exit_status if task_instance.state == "success" else 1) + + # Link to parent WorkChainNode (before storing) + if parent_workchain_node: + cj_node.base.links.add_incoming( + parent_workchain_node, + link_type=LinkType.CALL_CALC, + link_label=group_id, + ) + + # Add inputs BEFORE storing the node + _store_taskgroup_inputs(cj_node, task_instance, dag_run) + + # Now store the node (inputs are locked in) + cj_node.store() + + # Outputs can be added after storing + _store_taskgroup_outputs(cj_node, task_instance) + + logger.info(f"Created CalcJobNode {cj_node.pk} for TaskGroup {group_id}") + return cj_node + + +def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: + """Check if this DAG should be stored in AiiDA""" + dag_tags = getattr(dag_run.dag, "tags", []) + # Look for tags that indicate this is a CalcJob workflow + return any(tag in dag_tags for tag in ["aiida", "calcjob", "taskgroup"]) + + +def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: + """ + Create a WorkChainNode from a running Airflow DAG and store its inputs. 
+ + Returns: + The created and stored WorkChainNode + """ + wc_node: orm.WorkChainNode = orm.WorkChainNode() + wc_node.label = dag_run.dag_id + wc_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" + + wc_node.base.extras.set("airflow_dag_id", dag_run.dag_id) + wc_node.base.extras.set("airflow_run_id", dag_run.run_id) + + # Set process type to the DAG ID + wc_node.set_process_type(dag_run.dag_id) + wc_node.set_process_label("AirflowWorkChain") + + # Store DAG parameters + # Use conf if available, otherwise use default params + dag_conf = getattr(dag_run, "conf", {}) + dag_params = getattr(dag_run.dag, "params", {}) + + # Prefer conf values (runtime overrides), fall back to default params + params_to_store = {} + for key, param in dag_params.items(): + # Get actual value from conf or use default + if dag_conf and key in dag_conf: + params_to_store[key] = dag_conf[key] else: - state_str = None - - # Safely extract other attributes - execution_date = dagrun.execution_date if 'execution_date' not in unloaded else None - start_date = dagrun.start_date if 'start_date' not in unloaded else None - end_date = dagrun.end_date if 'end_date' not in unloaded else None - external_trigger = dagrun.external_trigger if 'external_trigger' not in unloaded else False - conf = dagrun.conf if 'conf' not in unloaded else {} - - logger.info(f"[DEBUG] Starting to store {event_type} event") - logger.info(f"[DEBUG] DB_PATH: {DB_PATH}") - logger.info(f"[DEBUG] dag_id: {dag_id}") - logger.info(f"[DEBUG] run_id: {run_id}") - - with sqlite3.connect(DB_PATH) as conn: - cursor = conn.cursor() - - # Test if table exists - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='dagrun_events'") - table_exists = cursor.fetchone() - logger.info(f"[DEBUG] Table exists: {table_exists is not None}") - - # Get DAG output for completed DAGs (pass extracted values instead of dagrun object) - dag_output = _get_dag_output_safe(dag_id, run_id) if event_type in ['success', 'failed'] else '{}' - - data_tuple = ( - dag_id, - run_id, - run_type_str, - state_str, - execution_date.isoformat() if execution_date else None, - start_date.isoformat() if start_date else None, - end_date.isoformat() if end_date else None, - external_trigger, - str(conf) if conf else '{}', - dag_output, - event_type, - datetime.now().isoformat() - ) + params_to_store[key] = _param_to_python(param) - logger.info(f"[DEBUG] Data tuple: {data_tuple}") + # Store with clean names (no prefix) + _store_dag_inputs_in_aiida(wc_node, params_to_store, prefix="") - cursor.execute(''' - INSERT INTO dagrun_events ( - dag_id, run_id, run_type, state, execution_date, - start_date, end_date, external_trigger, conf, dag_output, - event_type, event_timestamp, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
-            ''', data_tuple + (datetime.now().strftime('%Y-%m-%d %H:%M:%S'),))
+    wc_node.set_process_state("running")
+    wc_node.store()
-            conn.commit()
-            logger.info(f"[SUCCESS] Stored {event_type} event for DAG run {dag_id}/{run_id}")
+    logger.info(f"Created WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}")
+    return wc_node
-            # Verify insertion
-            cursor.execute("SELECT COUNT(*) FROM dagrun_events")
-            count = cursor.fetchone()[0]
-            logger.info(f"[DEBUG] Total events in DB: {count}")
-    except Exception as e:
-        logger.error(f"[ERROR] Failed to store dagrun event: {e}")
-        import traceback
-        logger.error(f"[ERROR] Traceback: {traceback.format_exc()}")
+def _finalize_workchain_node(wc_node: orm.WorkChainNode, dag_run: DagRun) -> None:
+    """
+    Finalize the WorkChainNode and create CalcJobNodes for all completed task groups.
+
+    Args:
+        wc_node: The WorkChainNode to finalize
+        dag_run: The completed DAG run
+    """
+    # Update process state to finished
+    wc_node.set_process_state("finished")
+    wc_node.set_exit_status(0)
-# Initialize database on import
-_init_database()
+    # Process each task in the DAG to find CalcJobTaskGroup parse tasks
+    task_instances = dag_run.get_task_instances()
+    for ti in task_instances:
+        if ti.state == "success" and should_create_calcjob_node_for_taskgroup(ti):
+            _create_calcjob_node_from_taskgroup(ti, wc_node, dag_run)
-class DagRunListener:
-    """Class-based DAG run listener."""
+    logger.info(f"Finalized WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}")
+
+
+# Airflow Listener Plugin
+class AiiDATaskGroupIntegrationListener:
+    """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance"""

     @hookimpl
     def on_dag_run_running(self, dag_run: DagRun, msg: str):
-        """Called when a DAG run enters the running state."""
-        logger.info(f"[CLASS LISTENER] DAG run started: {dag_run.dag_id}/{dag_run.run_id}")
-        _store_dagrun_event(dag_run, 'running')
+        # Provenance nodes are created in on_dag_run_success / on_dag_run_failed;
+        # this hook only logs the state transition.
+        logger.info(f"[HOOK] on_dag_run_running: {dag_run.dag_id}/{dag_run.run_id}")

     @hookimpl
     def on_dag_run_success(self, dag_run: DagRun, msg: str):
-        """Called when a DAG run completes successfully."""
-        logger.info(f"[CLASS LISTENER] DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}")
-        _store_dagrun_event(dag_run, 'success')
+        """
+        Called when a DAG run completes successfully.
+
+        Creates the WorkChainNode with inputs, then creates CalcJobNodes for all
+        completed task groups, and finally finalizes the WorkChainNode.
+        """
+        logger.info(f"[HOOK] on_dag_run_success: {dag_run.dag_id}/{dag_run.run_id}")
+
+        if not _should_integrate_dag_with_aiida(dag_run):
+            logger.debug(f"DAG {dag_run.dag_id} not tagged for AiiDA integration")
+            return
+
+        try:
+            logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}")
+            wc_node = _create_workchain_node_with_inputs(dag_run)
+
+            logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}")
+            _finalize_workchain_node(wc_node, dag_run)
+
+            logger.info(
+                f"Successfully integrated DAG {dag_run.dag_id} into AiiDA provenance"
+            )
+        except Exception as e:
+            logger.error(
+                f"Failed to integrate DAG {dag_run.dag_id} into AiiDA: {e}",
+                exc_info=True,
+            )

     @hookimpl
     def on_dag_run_failed(self, dag_run: DagRun, msg: str):
-        """Called when a DAG run fails."""
-        logger.info(f"[CLASS LISTENER] DAG run failed: {dag_run.dag_id}/{dag_run.run_id}")
-        _store_dagrun_event(dag_run, 'failed')
+        """
+        Called when a DAG run fails.
+
+        Creates a failed WorkChainNode for provenance tracking.
+ """ + logger.info(f"[HOOK] on_dag_run_failed: {dag_run.dag_id}/{dag_run.run_id}") + + if not _should_integrate_dag_with_aiida(dag_run): + return + + try: + # Create WorkChainNode for failed run + wc_node = _create_workchain_node_with_inputs(dag_run) + wc_node.set_process_state("excepted") + wc_node.set_exit_status(1) + logger.info( + f"Created failed WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}" + ) + except Exception as e: + logger.error( + f"Failed to create WorkChainNode for failed DAG: {e}", exc_info=True + ) + # Create listener instance -dag_run_listener = DagRunListener() +aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener() + +# Plugin registration class AiidaDagRunListener(AirflowPlugin): name = "aiida_dag_run_listener" - listeners = [dag_run_listener] + listeners = [aiida_taskgroup_listener]
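
Reviewer note: _should_integrate_dag_with_aiida only acts on DAG runs whose DAG carries one of the tags "aiida", "calcjob", or "taskgroup", and the decorator of arithmetic_add_multiply is outside the hunks shown here, so it is worth confirming the tag is present. A minimal sketch of the DAG-level declaration the listener expects; the decorator arguments below are illustrative, not the project's actual definition:

from airflow.decorators import dag
from airflow.sdk.definitions.param import Param


@dag(
    dag_id="arithmetic_add_multiply",
    tags=["aiida", "calcjob"],  # any one recognized tag enables the AiiDA integration
    params={
        # defaults overridden at trigger time by the conf built in run_arithmetic_dag.py
        "add_x": Param(10, type="integer"),
        "add_y": Param(5, type="integer"),
        "multiply_x": Param(7, type="integer"),
        "multiply_y": Param(3, type="integer"),
    },
)
def arithmetic_add_multiply():
    ...  # task groups wired as in example_dags/arithmetic_add.py


arithmetic_add_multiply()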
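
Reviewer note: a quick way to check what the listener wrote is to query AiiDA for the WorkChainNode via the airflow_run_id extra set in _create_workchain_node_with_inputs and walk its CALL_CALC links. A minimal verification sketch, assuming the same default AiiDA profile the plugin loads; the run_id value is a placeholder for the one reported by the trigger:

from aiida import load_profile, orm

load_profile()

run_id = "manual__2025-01-01T00:00:00+00:00"  # placeholder: use the real Airflow run_id

qb = orm.QueryBuilder()
# WorkChainNode created by on_dag_run_success, located through the listener's extras
qb.append(orm.WorkChainNode, filters={"extras.airflow_run_id": run_id}, tag="wc", project="*")
# CalcJobNodes attached with CALL_CALC links by _create_calcjob_node_from_taskgroup
qb.append(orm.CalcJobNode, with_incoming="wc", project="*")

for wc, cj in qb.all():
    print(f"WorkChain {wc.pk} -> CalcJob {cj.pk} ({cj.label}), exit_status={cj.exit_status}")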