Commit c10a93c

add airflow example

ghstack-source-id: 31127b9
Pull Request resolved: facebookresearch#33
1 parent e5e5ef4

4 files changed: +135 −0

torchrecipes/paved_path/README.md (3 additions, 0 deletions)

Appends a new section after the note about the "aws_batch://torchx/..." tracking URL and the `torchx status` example:

## Pipelines

As your applications grow more complicated, you can organize them into pipelines and manage and monitor them with frameworks like Airflow, Kubeflow, etc., as sketched below.

* [Airflow example](https://github.com/facebookresearch/recipes/tree/main/torchrecipes/paved_path/airflow)
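As a rough illustration of the idea (not part of this commit; the DAG id, task names, and commands here are made up), an Airflow pipeline is just a Python file that declares tasks and the order they run in:

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

# Hypothetical two-step pipeline: preprocess, then train.
with DAG(
    "example_pipeline",
    start_date=datetime(2022, 8, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    preprocess = BashOperator(task_id="preprocess", bash_command="echo preprocess")
    train = BashOperator(task_id="train", bash_command="echo train")
    preprocess >> train  # train runs only after preprocess succeeds
```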
torchrecipes/paved_path/airflow/README.md (new file, 19 additions)

# Airflow example

1. Install and start an Airflow server
```bash
./setup.sh
```
Learn more about Airflow in the [Quick Start](https://airflow.apache.org/docs/apache-airflow/stable/start/local.html)

2. Create a DAG
See the example in `train_charnn.py` (a quick parse check follows these steps)

3. Run a task instance
```bash
airflow tasks run train_charnn train 2022-08-01
```

4. Backfill the DAG over 2 days
```bash
airflow dags backfill train_charnn --start-date 2022-08-01 --end-date 2022-08-02
```
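One way to check that the DAG file parses before running it (a suggested check, not part of this commit) is to import it directly. Note that `train_charnn.py` reads `ECR_URL` from the environment at import time, so the variable must be set; the value below is a placeholder:

```python
# Hedged sanity check: import the DAG file and inspect it. Assumes Airflow is
# installed and that this runs from the directory containing train_charnn.py.
import os

os.environ.setdefault("ECR_URL", "<your-ecr-url>")  # placeholder value

from train_charnn import dag

print(dag.dag_id)    # expected: train_charnn
print(dag.task_ids)  # expected: ['train', 'wait_for_job', 'parse_output']
```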
torchrecipes/paved_path/airflow/setup.sh (new file, 20 additions)

```bash
#!/bin/bash

install_airflow=true
start_local_airflow=true

if [ "$install_airflow" = true ]
then
    # Build prerequisites for Airflow's dependencies
    pip3 install --upgrade pip
    sudo apt install libffi-dev
    pip3 install setuptools-rust
    # Pin Airflow 2.3.0 with the official constraint file for Python 3.8
    pip3 install "apache-airflow[celery]==2.3.0" --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.0/constraints-3.8.txt"
    # AWS integration used by the train_charnn DAG
    pip3 install apache-airflow-providers-amazon
    pip3 install boto3
fi

# https://airflow.apache.org/docs/apache-airflow/stable/start/local.html
if [ "$start_local_airflow" = true ]
then
    # Runs a local all-in-one Airflow instance (webserver, scheduler, DB)
    airflow standalone
fi
```
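A quick way to confirm the install succeeded (a suggested check, not part of the commit) is to verify that the packages the DAG below depends on import cleanly:

```python
# Post-install sanity check: these are exactly the imports train_charnn.py uses.
import airflow
import boto3
from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook

print(airflow.__version__)  # expected: 2.3.0, matching the pinned install above
```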
torchrecipes/paved_path/airflow/train_charnn.py (new file, 93 additions)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from datetime import datetime, timedelta

import boto3.session

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook

REGION = "us-west-2"
JOB_QUEUE = "torchx-gpu"
ECR_URL = os.environ["ECR_URL"]


default_args = {
    "depends_on_past": False,
    "email": ["[email protected]"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 0,
    "retry_delay": timedelta(minutes=5),
}


dag = DAG(
    "train_charnn",
    default_args=default_args,
    description="A DAG to train charnn in AWS Batch",
    schedule_interval="@daily",
    catchup=False,
    start_date=datetime(2022, 8, 1),
    tags=["aws_batch"],
)


# Launch the training job on AWS Batch via TorchX. The grep keeps only the
# job handle ("aws_batch://torchx/<queue>:<job-name>") as the last line of
# stdout, which do_xcom_push=True pushes to XCom for the downstream task.
train = BashOperator(
    task_id="train",
    bash_command="""AWS_DEFAULT_REGION=$REGION \
        torchx run --workspace '' -s aws_batch \
        -cfg queue=$JOB_QUEUE,image_repo=$ECR_URL/charnn dist.ddp \
        --script charnn/main.py --image $ECR_URL/charnn:latest \
        --cpu 8 --gpu 2 -j 1x2 --memMB 20480 2>&1 \
        | grep -Eo aws_batch://torchx/$JOB_QUEUE:main-[a-z0-9]+""",
    env={
        "REGION": REGION,
        "JOB_QUEUE": JOB_QUEUE,
        "ECR_URL": ECR_URL,
    },
    dag=dag,
    do_xcom_push=True,
)


def wait_for_batch_job(**context) -> bool:
    """Resolve the AWS Batch job id from the TorchX handle and wait on it."""
    session = boto3.session.Session()
    client = session.client("batch", region_name=REGION)
    # Pull the job handle that the train task pushed to XCom.
    job = context["ti"].xcom_pull(task_ids="train")
    job_desc = job.split("/")[-1]  # "<queue>:<job-name>"
    queue_name, job_name = job_desc.split(":")
    job_id = client.list_jobs(
        jobQueue=queue_name,
        filters=[{"name": "JOB_NAME", "values": [job_name]}],
    )["jobSummaryList"][0]["jobId"]
    waiter = BatchWaitersHook(region_name=REGION)
    try:
        waiter.wait_for_job(job_id)
        return True
    except Exception:
        return False


wait_for_job = PythonOperator(
    task_id="wait_for_job",
    python_callable=wait_for_batch_job,
    dag=dag,
)


# Echo the status returned by wait_for_job (True/False), pulled from XCom.
parse_output = BashOperator(
    task_id="parse_output",
    bash_command="echo output: {{ ti.xcom_pull(task_ids='wait_for_job') }}",
    dag=dag,
)


train >> wait_for_job >> parse_output
```
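Data flows between the three tasks via XCom: with `do_xcom_push=True`, the `train` BashOperator pushes the last line of its stdout (the grep'd TorchX job handle) to XCom, and `wait_for_batch_job` pulls and parses it. A standalone illustration of that parsing, using a made-up handle:

```python
# Illustration only: how wait_for_batch_job splits the TorchX job handle.
# The handle value below is made up; real ones come from the train task's grep.
job = "aws_batch://torchx/torchx-gpu:main-abc123"

job_desc = job.split("/")[-1]        # "torchx-gpu:main-abc123"
queue_name, job_name = job_desc.split(":")
print(queue_name, job_name)          # torchx-gpu main-abc123
```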
