-
Notifications
You must be signed in to change notification settings - Fork 83
/
Copy pathdag_ml_pipeline_amazon_video_reviews.py
205 lines (171 loc) · 5.9 KB
/
dag_ml_pipeline_amazon_video_reviews.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from __future__ import print_function
import json
import requests
from datetime import datetime
# airflow operators
import airflow
from airflow.models import DAG
from airflow.utils.trigger_rule import TriggerRule
from airflow.operators.python_operator import BranchPythonOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
# airflow sagemaker operators
from airflow.contrib.operators.sagemaker_training_operator \
import SageMakerTrainingOperator
from airflow.contrib.operators.sagemaker_tuning_operator \
import SageMakerTuningOperator
from airflow.contrib.operators.sagemaker_transform_operator \
import SageMakerTransformOperator
from airflow.contrib.hooks.aws_hook import AwsHook
# sagemaker sdk
import boto3
import sagemaker
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.estimator import Estimator
from sagemaker.tuner import HyperparameterTuner
# airflow sagemaker configuration
from sagemaker.workflow.airflow import training_config
from sagemaker.workflow.airflow import tuning_config
from sagemaker.workflow.airflow import transform_config_from_estimator
# ml workflow specific
from pipeline import prepare, preprocess
import config as cfg
# =============================================================================
# functions
# =============================================================================
def is_hpo_enabled(conf=None):
    """Check whether hyper-parameter optimization is enabled in the config.

    Args:
        conf (dict, optional): configuration mapping to inspect. Defaults to
            the module-level ``config`` so existing callers are unaffected.

    Returns:
        bool: True only when ``conf["job_level"]["run_hyperparameter_opt"]``
        is the string "yes" (case-insensitive); False when either key is
        absent or the value is anything else.
    """
    if conf is None:
        conf = config
    # chained .get() collapses the key-presence checks; missing keys read as "no"
    flag = conf.get("job_level", {}).get("run_hyperparameter_opt", "no")
    return flag.lower() == "yes"
def get_sagemaker_role_arn(role_name, region_name):
    """Resolve the full ARN of an IAM role from its name.

    Args:
        role_name (str): name of the IAM role (e.g. the SageMaker execution role).
        region_name (str): AWS region used to create the IAM client.

    Returns:
        str: the role's ARN as reported by the IAM GetRole API.
    """
    iam_client = boto3.client('iam', region_name=region_name)
    return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
# =============================================================================
# setting up training, tuning and transform configuration
# =============================================================================
# read config file (project-local module; single source of all job settings)
config = cfg.config
# set configuration for tasks
# credentials come from the Airflow connection named "airflow-sagemaker"
hook = AwsHook(aws_conn_id='airflow-sagemaker')
region = config["job_level"]["region_name"]
sess = hook.get_session(region_name=region)
# resolve the IAM role ARN that the SageMaker jobs will assume
role = get_sagemaker_role_arn(
config["train_model"]["sagemaker_role"],
sess.region_name)
# ECR image URI for the built-in factorization-machines algorithm in this region
container = get_image_uri(sess.region_name, 'factorization-machines')
hpo_enabled = is_hpo_enabled()
# create estimator (shared by the training, tuning and transform configs below)
fm_estimator = Estimator(
image_name=container,
role=role,
sagemaker_session=sagemaker.session.Session(sess),
**config["train_model"]["estimator_config"]
)
# train_config specifies SageMaker training configuration
train_config = training_config(
estimator=fm_estimator,
inputs=config["train_model"]["inputs"])
# create tuner wrapping the estimator for hyper-parameter search
fm_tuner = HyperparameterTuner(
estimator=fm_estimator,
**config["tune_model"]["tuner_config"]
)
# create tuning config
tuner_config = tuning_config(
tuner=fm_tuner,
inputs=config["tune_model"]["inputs"])
# create transform config; task_id/task_type must point at whichever upstream
# task (tuning or training) actually produced the model artifacts
transform_config = transform_config_from_estimator(
estimator=fm_estimator,
task_id="model_tuning" if hpo_enabled else "model_training",
task_type="tuning" if hpo_enabled else "training",
**config["batch_transform"]["transform_config"]
)
# =============================================================================
# define airflow DAG and tasks
# =============================================================================
# define airflow DAG
args = {
'owner': 'airflow',
'start_date': airflow.utils.dates.days_ago(2)
}
# manually-triggered (schedule_interval=None) pipeline; concurrency/max_active_runs
# of 1 keep a single pipeline run (and a single task at a time) in flight
dag = DAG(
dag_id='sagemaker-ml-pipeline',
default_args=args,
schedule_interval=None,
concurrency=1,
max_active_runs=1,
user_defined_filters={'tojson': lambda s: json.JSONEncoder().encode(s)}
)
# set the tasks in the DAG
# dummy operator marking the start of the pipeline
init = DummyOperator(
task_id='start',
dag=dag
)
# preprocess the data (project-local pipeline.preprocess module)
preprocess_task = PythonOperator(
task_id='preprocessing',
dag=dag,
provide_context=False,
python_callable=preprocess.preprocess,
op_kwargs=config["preprocess_data"])
# prepare the data for training (project-local pipeline.prepare module)
prepare_task = PythonOperator(
task_id='preparing',
dag=dag,
provide_context=False,
python_callable=prepare.prepare,
op_kwargs=config["prepare_data"]
)
# route the run to either the tuning or the plain-training task, based on
# the config flag captured in hpo_enabled at DAG-parse time
branching = BranchPythonOperator(
task_id='branching',
dag=dag,
python_callable=lambda: "model_tuning" if hpo_enabled else "model_training")
# launch sagemaker training job and wait until it completes
train_model_task = SageMakerTrainingOperator(
task_id='model_training',
dag=dag,
config=train_config,
aws_conn_id='airflow-sagemaker',
wait_for_completion=True,
check_interval=30
)
# launch sagemaker hyperparameter job and wait until it completes
tune_model_task = SageMakerTuningOperator(
task_id='model_tuning',
dag=dag,
config=tuner_config,
aws_conn_id='airflow-sagemaker',
wait_for_completion=True,
check_interval=30
)
# launch sagemaker batch transform job and wait until it completes;
# ONE_SUCCESS lets it run after whichever branch (training or tuning)
# actually executed — the other branch is skipped
batch_transform_task = SageMakerTransformOperator(
task_id='predicting',
dag=dag,
config=transform_config,
aws_conn_id='airflow-sagemaker',
wait_for_completion=True,
check_interval=30,
trigger_rule=TriggerRule.ONE_SUCCESS
)
# placeholder terminal task; no cleanup logic is attached here
cleanup_task = DummyOperator(
task_id='cleaning_up',
dag=dag)
# wire up the pipeline: a linear prep chain, then a branch into either the
# tuning or the training path, both converging on the batch transform step
init >> preprocess_task >> prepare_task >> branching
branching >> tune_model_task >> batch_transform_task
branching >> train_model_task >> batch_transform_task
batch_transform_task >> cleanup_task