Commit b32f4dc

committed
add airflow example
ghstack-source-id: 571c14f
Pull Request resolved: #33
1 parent d980846 commit b32f4dc

File tree

4 files changed: +143 -0 lines


torchrecipes/paved_path/README.md

Lines changed: 3 additions & 0 deletions
@@ -58,3 +58,6 @@ Note that it will output a URL like "aws_batch://torchx/..." that is used to tra
torchx status "aws_batch://torchx/..."
```

## Pipelines
As your applications grow more complex, you can compose them into pipelines and manage and monitor them with frameworks such as Airflow or Kubeflow.
* [Airflow example](https://github.com/facebookresearch/recipes/tree/main/torchrecipes/paved_path/airflow)
torchrecipes/paved_path/airflow/README.md

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
# Airflow example

1. Install and start an Airflow server
```bash
./setup.sh
```
> **_NOTE:_** The Airflow UI can be accessed at http://0.0.0.0:8080 (replace the address with your EC2 instance address for public access). Learn more about Airflow in the [Quick Start](https://airflow.apache.org/docs/apache-airflow/stable/start/local.html).

2. Create a DAG definition. See the example in `train_charnn.py`.

3. Set `dags_folder` in `~/airflow/airflow.cfg` to the folder containing your DAG definitions so that Airflow can discover them (a sample `airflow.cfg` snippet follows below).

4. Run a task instance
```bash
airflow tasks run train_charnn train 2022-08-01
```
> **_NOTE:_** The task instance can be monitored in the UI: http://0.0.0.0:8080/taskinstance/list

5. Backfill the DAG over two days
```bash
airflow dags backfill train_charnn --start-date 2022-08-01 --end-date 2022-08-02
```
> **_NOTE:_** The DAG runs can be monitored in the UI: http://0.0.0.0:8080/dagrun/list/
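For step 3 above, the setting to edit is `dags_folder` in the `[core]` section of `~/airflow/airflow.cfg`. A minimal sketch, assuming the repository is cloned to `/home/ubuntu/recipes` (a hypothetical path; adjust it to your checkout):

```ini
[core]
# Point Airflow at the directory that contains train_charnn.py
dags_folder = /home/ubuntu/recipes/torchrecipes/paved_path/airflow
```

Airflow reads `airflow.cfg` at startup, so restart `airflow standalone` (or the scheduler and webserver) after changing it.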
torchrecipes/paved_path/airflow/setup.sh

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
#!/bin/bash

install_airflow=true
start_local_airflow=true

if [ "$install_airflow" = true ]
then
    pip3 install --upgrade pip
    sudo apt install libffi-dev
    pip3 install setuptools-rust
    pip3 install "apache-airflow[celery]==2.3.0" --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.0/constraints-3.8.txt"
    pip3 install apache-airflow-providers-amazon
    pip3 install boto3
fi

# https://airflow.apache.org/docs/apache-airflow/stable/start/local.html
if [ "$start_local_airflow" = true ]
then
    airflow standalone
fi
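Note that `airflow standalone` runs in the foreground and initializes the metadata database on first start. A quick sanity check from a second shell once it is up (a sketch, not part of the committed script):

```bash
# Confirm the install and that the example DAG is discoverable
# (requires dags_folder to point at the airflow/ directory, see step 3 of the README).
airflow version
airflow dags list | grep train_charnn
```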
torchrecipes/paved_path/airflow/train_charnn.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from datetime import datetime, timedelta

import boto3.session

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook

REGION = "us-west-2"
JOB_QUEUE = "torchx-gpu"
ECR_URL = os.environ["ECR_URL"]


default_args = {
    "depends_on_past": False,
    "email": ["[email protected]"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 0,
    "retry_delay": timedelta(minutes=5),
}


dag = DAG(
    "train_charnn",
    default_args=default_args,
    description="A DAG to train charnn in AWS Batch",
    schedule_interval="@daily",
    catchup=False,
    start_date=datetime(2022, 8, 1),
    tags=["aws_batch"],
)


# This example uses the torchx CLI with BashOperator.
# We could also use PythonOperator to achieve the same thing.
train = BashOperator(
    task_id="train",
    bash_command=f"""AWS_DEFAULT_REGION=$REGION \
        torchx run --workspace '' -s aws_batch \
        -cfg queue={JOB_QUEUE},image_repo={ECR_URL}/charnn dist.ddp \
        --script charnn/main.py --image {ECR_URL}/charnn:latest \
        --cpu 8 --gpu 2 -j 1x2 --memMB 20480 2>&1 \
        | grep -Eo aws_batch://torchx/{JOB_QUEUE}:main-[a-z0-9]+""",
    env={
        "REGION": REGION,
        "JOB_QUEUE": JOB_QUEUE,
        "ECR_URL": ECR_URL,
    },
    append_env=True,
    dag=dag,
    do_xcom_push=True,
)


def wait_for_batch_job(**context) -> bool:
    session = boto3.session.Session()
    client = session.client("batch", region_name=REGION)
    # XComs are a mechanism that lets tasks talk to each other.
    # Learn more at https://airflow.apache.org/docs/apache-airflow/stable/concepts/xcoms.html
    job = context["ti"].xcom_pull(task_ids="train")
    job_desc = job.split("/")[-1]
    queue_name, job_name = job_desc.split(":")
    job_id = client.list_jobs(
        jobQueue=queue_name,
        filters=[{"name": "JOB_NAME", "values": [job_name]}],
    )["jobSummaryList"][0]["jobId"]
    waiter = BatchWaitersHook(region_name=REGION)
    try:
        waiter.wait_for_job(job_id)
        return True
    except Exception:
        return False


wait_for_job = PythonOperator(
    task_id="wait_for_job",
    python_callable=wait_for_batch_job,
    dag=dag,
)


parse_output = BashOperator(
    task_id="parse_output",
    bash_command="echo {{ ti.xcom_pull(task_ids='wait_for_job') }}",
    dag=dag,
)


train >> wait_for_job >> parse_output
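Taken together: `train` submits the job through the torchx CLI and, since `do_xcom_push=True`, pushes the grepped `aws_batch://torchx/...` handle to XCom; `wait_for_batch_job` pulls that handle, resolves the AWS Batch job ID, and blocks until the job finishes; `parse_output` echoes the boolean result pushed by `wait_for_job`. One way to exercise the whole chain in-process, without the scheduler (a sketch; assumes `ECR_URL` is exported and AWS credentials with Batch access are configured):

```bash
# Run every task of train_charnn for a single logical date, in the foreground.
export ECR_URL=<account>.dkr.ecr.us-west-2.amazonaws.com   # hypothetical registry URL
airflow dags test train_charnn 2022-08-01
```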
