3 changes: 3 additions & 0 deletions torchrecipes/paved_path/README.md
@@ -58,3 +58,6 @@ Note that it will output a URL like "aws_batch://torchx/..." that is used to tra
torchx status "aws_batch://torchx/..."
```

## Pipelines
As your applications grow more complex, you can compose them into pipelines and manage and monitor them with frameworks like Airflow, Kubeflow, etc.
* [Airflow example](https://github.com/facebookresearch/recipes/tree/main/torchrecipes/paved_path/airflow)
22 changes: 22 additions & 0 deletions torchrecipes/paved_path/airflow/README.md
@@ -0,0 +1,22 @@
# Airflow example
1. Install and start an Airflow server
```bash
./setup.sh
```
> **_NOTE:_** The Airflow UI can be accessed at http://0.0.0.0:8080 (replace the address with your EC2 instance address for public access). Learn more about Airflow in the [Quick Start](https://airflow.apache.org/docs/apache-airflow/stable/start/local.html).

2. Create a DAG definition. See the example in `train_charnn.py`.

3. In `~/airflow/airflow.cfg`, set `dags_folder` to the folder containing your DAG definitions so that Airflow can discover them.
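For example, assuming this repository is checked out at `~/recipes` (the path below is an assumption; adjust it to your checkout), the following sketch points Airflow at this folder and checks that the DAG is picked up:
```bash
# Point dags_folder at the folder containing train_charnn.py
# (example path; replace it with your checkout location)
sed -i "s|^dags_folder = .*|dags_folder = $HOME/recipes/torchrecipes/paved_path/airflow|" ~/airflow/airflow.cfg

# Verify that Airflow discovers the DAG
airflow dags list
```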

4. Run a task instance
```bash
airflow tasks run train_charnn train 2022-08-01
```
> **_NOTE:_** The task instance can be monitored in the UI: http://0.0.0.0:8080/taskinstance/list
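
If you prefer the command line, the state of this task instance can also be queried with the same DAG id, task id, and execution date:
```bash
# Print the current state of this task instance (e.g. queued, running, success)
airflow tasks state train_charnn train 2022-08-01
```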

5. Backfill the DAG over two days
```bash
airflow dags backfill train_charnn --start-date 2022-08-01 --end-date 2022-08-02
```
> **_NOTE:_** The DAG runs can be monitored in the UI: http://0.0.0.0:8080/dagrun/list/
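
The runs created by the backfill can also be listed from the command line:
```bash
# List all runs of the train_charnn DAG, including the backfilled ones
airflow dags list-runs -d train_charnn
```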
20 changes: 20 additions & 0 deletions torchrecipes/paved_path/airflow/setup.sh
@@ -0,0 +1,20 @@
#!/bin/bash

install_airflow=true
start_local_airflow=true

if [ "$install_airflow" = true ]
then
pip3 install --upgrade pip
sudo apt install libffi-dev
pip3 install setuptools-rust
pip3 install "apache-airflow[celery]==2.3.0" --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.0/constraints-3.8.txt"
pip3 install apache-airflow-providers-amazon
pip3 install boto3
fi

# https://airflow.apache.org/docs/apache-airflow/stable/start/local.html
if [ "$start_local_airflow" = true ]
then
airflow standalone
fi
98 changes: 98 additions & 0 deletions torchrecipes/paved_path/airflow/train_charnn.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from datetime import datetime, timedelta

import boto3.session

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook

REGION = "us-west-2"
JOB_QUEUE = "torchx-gpu"
ECR_URL = os.environ["ECR_URL"]


default_args = {
    "depends_on_past": False,
    "email": ["[email protected]"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 0,
    "retry_delay": timedelta(minutes=5),
}


dag = DAG(
    "train_charnn",
    default_args=default_args,
    description="A DAG to train charnn in AWS Batch",
    schedule_interval="@daily",
    catchup=False,
    start_date=datetime(2022, 8, 1),
    tags=["aws_batch"],
)


# This example uses the torchx CLI with BashOperator.
# We could also use PythonOperator to achieve the same thing.
# The command submits a distributed training job to AWS Batch and greps the
# resulting job handle ("aws_batch://torchx/<queue>:<job-name>"), which is
# pushed to XCom (do_xcom_push=True) for the downstream task.
train = BashOperator(
    task_id="train",
    bash_command=f"""AWS_DEFAULT_REGION=$REGION \
torchx run --workspace '' -s aws_batch \
-cfg queue={JOB_QUEUE},image_repo={ECR_URL}/charnn dist.ddp \
--script charnn/main.py --image {ECR_URL}/charnn:latest \
--cpu 8 --gpu 2 -j 1x2 --memMB 20480 2>&1 \
| grep -Eo aws_batch://torchx/{JOB_QUEUE}:main-[a-z0-9]+""",
    env={
        "REGION": REGION,
        "JOB_QUEUE": JOB_QUEUE,
        "ECR_URL": ECR_URL,
    },
    append_env=True,
    dag=dag,
    do_xcom_push=True,
)


def wait_for_batch_job(**context) -> bool:
    """Wait for the AWS Batch job submitted by the `train` task to finish."""
    session = boto3.session.Session()
    client = session.client("batch", region_name=REGION)
    # XComs are a mechanism that lets tasks talk to each other.
    # Learn more at https://airflow.apache.org/docs/apache-airflow/stable/concepts/xcoms.html
    # The pulled value is the job handle, e.g. "aws_batch://torchx/<queue>:<job-name>".
    job = context["ti"].xcom_pull(task_ids="train")
    job_desc = job.split("/")[-1]
    queue_name, job_name = job_desc.split(":")
    # Resolve the job name to an AWS Batch job id, then wait for a terminal state.
    job_id = client.list_jobs(
        jobQueue=queue_name,
        filters=[{"name": "JOB_NAME", "values": [job_name]}],
    )["jobSummaryList"][0]["jobId"]
    waiter = BatchWaitersHook(region_name=REGION)
    try:
        waiter.wait_for_job(job_id)
        return True
    except Exception:
        return False


wait_for_job = PythonOperator(
    task_id="wait_for_job",
    python_callable=wait_for_batch_job,
    dag=dag,
)


parse_output = BashOperator(
    task_id="parse_output",
    bash_command="echo {{ ti.xcom_pull(task_ids='wait_for_job') }}",
    dag=dag,
)


train >> wait_for_job >> parse_output