Skip to content
43 changes: 43 additions & 0 deletions run_arithmetic_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Trigger the ``arithmetic_add_multiply`` example DAG.

Points Airflow at the provider's bundled ``example_dags`` folder, verifies
the DAG actually loads, then triggers a run through the local API client
(which requires a scheduler to be running).
"""
from pathlib import Path
import os

from airflow.api.client.local_client import Client

# Set AIRFLOW__CORE__DAGS_FOLDER *before* importing anything that reads the
# Airflow config (DagBag below), otherwise the setting is ignored.
dag_folder = str(
    Path(__file__).parent / "src" / "airflow_provider_aiida" / "example_dags"
)
os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = dag_folder

# Import AFTER setting the environment variable.
from airflow.models import DagBag

# Work directories the DAG's tasks read from / write to.
Path("/tmp/airflow/local_workdir").mkdir(parents=True, exist_ok=True)
Path("/tmp/airflow/remote_workdir").mkdir(parents=True, exist_ok=True)

DAG_ID = "arithmetic_add_multiply"

# Run configuration handed to the DAG as ``params``.
conf = {
    "machine": "localhost",
    "local_workdir": "/tmp/airflow/local_workdir",
    "remote_workdir": "/tmp/airflow/remote_workdir",
    "add_x": 10,
    "add_y": 5,
    "multiply_x": 7,
    "multiply_y": 3,
}

# Load the DagBag and fail fast if the DAG is missing or failed to import,
# instead of blindly triggering a run that the scheduler cannot execute.
dagbag = DagBag(dag_folder=dag_folder, include_examples=False)
dag = dagbag.get_dag(DAG_ID)
if dag is None:
    raise RuntimeError(
        f"DAG {DAG_ID!r} not found in {dag_folder}; "
        f"import errors: {dagbag.import_errors}"
    )

# Trigger DAG using the local API client (requires scheduler to be running).
client = Client()
client.trigger_dag(dag_id=DAG_ID, conf=conf)
56 changes: 49 additions & 7 deletions src/airflow_provider_aiida/example_dags/arithmetic_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
class AddJobTaskGroup(CalcJobTaskGroup):
"""Addition job task group - directly IS a TaskGroup"""

# Define AiiDA input/output port names (like in aiida-core CalcJob.define())
# AIIDA_INPUT_PORTS = ['x', 'y']
# AIIDA_OUTPUT_PORTS = ['sum']
Comment on lines +14 to +16
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed i think


def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workdir: str,
x: int, y: int, sleep: int, **kwargs):
self.x = x
Expand All @@ -20,11 +24,20 @@ def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workd

def prepare(self, **context) -> Dict[str, Any]:
"""Prepare addition job inputs"""
# Resolve template variables from params
from airflow.models import TaskInstance
ti: TaskInstance = context['task_instance']
params = context['params']

x = params['add_x']
y = params['add_y']
sleep = 3 # or get from params if needed

to_upload_files = {}

submission_script = f"""
sleep {self.sleep}
echo "$(({self.x}+{self.y}))" > result.out
sleep {sleep}
echo "$(({x}+{y}))" > result.out
"""

to_receive_files = {"result.out": "addition_result.txt"}
Expand All @@ -34,6 +47,9 @@ def prepare(self, **context) -> Dict[str, Any]:
context['task_instance'].xcom_push(key='submission_script', value=submission_script)
context['task_instance'].xcom_push(key='to_receive_files', value=to_receive_files)

# Push AiiDA inputs for provenance (matches AIIDA_INPUT_PORTS)
context['task_instance'].xcom_push(key='aiida_inputs', value={'x': x, 'y': y})

return {
"to_upload_files": to_upload_files,
"submission_script": submission_script,
Expand All @@ -59,7 +75,7 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]:
continue

result_content = file_path.read_text().strip()
print(f"Addition result ({self.x} + {self.y}): {result_content}")
print(f"Addition result: {result_content}")
results[file_key] = int(result_content)

except Exception as e:
Expand All @@ -69,12 +85,21 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]:
# Store both exit status and results in XCom
final_result = (exit_status, results)
context['task_instance'].xcom_push(key='final_result', value=final_result)

# Push AiiDA outputs for provenance (matches AIIDA_OUTPUT_PORTS)
if 'result.out' in results:
context['task_instance'].xcom_push(key='aiida_outputs', value={'sum': results['result.out']})

return final_result


class MultiplyJobTaskGroup(CalcJobTaskGroup):
"""Multiplication job task group - directly IS a TaskGroup"""

# Define AiiDA input/output port names (like in aiida-core CalcJob.define())
# AIIDA_INPUT_PORTS = ['x', 'y']
# AIIDA_OUTPUT_PORTS = ['result']

def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workdir: str,
x: int, y: int, sleep: int, **kwargs):
self.x = x
Expand All @@ -84,12 +109,21 @@ def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workd

def prepare(self, **context) -> Dict[str, Any]:
"""Prepare multiplication job inputs"""
# Resolve template variables from params
from airflow.models import TaskInstance
ti: TaskInstance = context['task_instance']
params = context['params']

x = params['multiply_x']
y = params['multiply_y']
sleep = 2 # or get from params if needed

to_upload_files = {}

submission_script = f"""
sleep {self.sleep}
echo "$(({self.x}*{self.y}))" > multiply_result.out
echo "Operation: {self.x} * {self.y}" > operation.log
sleep {sleep}
echo "$(({x}*{y}))" > multiply_result.out
echo "Operation: {x} * {y}" > operation.log
"""

to_receive_files = {
Expand All @@ -102,6 +136,9 @@ def prepare(self, **context) -> Dict[str, Any]:
context['task_instance'].xcom_push(key='submission_script', value=submission_script)
context['task_instance'].xcom_push(key='to_receive_files', value=to_receive_files)

# Push AiiDA inputs for provenance (matches AIIDA_INPUT_PORTS)
context['task_instance'].xcom_push(key='aiida_inputs', value={'x': x, 'y': y})

return {
"to_upload_files": to_upload_files,
"submission_script": submission_script,
Expand Down Expand Up @@ -146,6 +183,11 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]:
# Store both exit status and results in XCom
final_result = (exit_status, results)
context['task_instance'].xcom_push(key='final_result', value=final_result)

# Push AiiDA outputs for provenance (matches AIIDA_OUTPUT_PORTS)
if 'result' in results:
context['task_instance'].xcom_push(key='aiida_outputs', value={'result': results['result']})

return final_result


Expand Down Expand Up @@ -238,4 +280,4 @@ def combine_results():

# Direct usage - add_job and multiply_job ARE TaskGroups!
combine_task = combine_results()
[add_job, multiply_job] >> combine_task
[add_job, multiply_job] >> combine_task
Loading