-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from Aura-healthcare/orchestration_train
Orchestration train
- Loading branch information
Showing
25 changed files
with
27,687 additions
and
923 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import sys | ||
from datetime import datetime as dt | ||
from sklearn.ensemble import RandomForestClassifier | ||
import datetime | ||
import xgboost as xgb | ||
import numpy as np | ||
|
||
PROJECT_FOLDER = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data') | ||
|
||
ML_DATASET_OUTPUT_FOLDER = "/opt/airflow/output" | ||
AIRFLOW_PREFIX_TO_DATA = '/opt/airflow/data/' | ||
MLRUNS_DIR = '/mlruns' | ||
|
||
TRAIN_DATA = os.path.join(AIRFLOW_PREFIX_TO_DATA, "train/df_ml_train.csv") | ||
TEST_DATA = os.path.join(AIRFLOW_PREFIX_TO_DATA , "test/df_ml_test.csv") | ||
FEATURE_TRAIN_PATH= os.path.join(ML_DATASET_OUTPUT_FOLDER, "ml_train.csv") | ||
FEATURE_TEST_PATH= os.path.join(ML_DATASET_OUTPUT_FOLDER, "ml_test.csv") | ||
|
||
COL_TO_DROP = ['interval_index', 'interval_start_time', 'set'] | ||
|
||
START_DATE = dt(2021, 8, 1) | ||
CONCURRENCY = 4 | ||
SCHEDULE_INTERVAL = datetime.timedelta(hours=2) | ||
DEFAULT_ARGS = {'owner': 'airflow'} | ||
|
||
TRACKING_URI = 'http://mlflow:5000' | ||
|
||
MODEL_PARAM = { | ||
'model': xgb.XGBClassifier(), | ||
'grid_parameters': { | ||
'nthread':[4], | ||
'learning_rate': [0.1, 0.01, 0.05], | ||
'max_depth': np.arange(3, 5, 2), | ||
'scale_pos_weight':[1], | ||
'n_estimators': np.arange(15, 25, 2), | ||
'missing':[-999]} | ||
} | ||
|
||
MODELS_PARAM = { | ||
'xgboost': { | ||
'model': xgb.XGBClassifier(), | ||
'grid_parameters': { | ||
'nthread':[4], | ||
'learning_rate': [0.1, 0.01, 0.05], | ||
'max_depth': np.arange(3, 5, 2), | ||
'scale_pos_weight':[1], | ||
'n_estimators': np.arange(15, 25, 2), | ||
'missing':[-999] | ||
} | ||
}, | ||
'random_forest': { | ||
'model': RandomForestClassifier(), | ||
'grid_parameters': { | ||
'min_samples_leaf': np.arange(1, 5, 1), | ||
'max_depth': np.arange(1, 7, 1), | ||
'max_features': ['auto'], | ||
'n_estimators': np.arange(10, 20, 2)} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
import sys | ||
from datetime import datetime, timedelta, datetime | ||
|
||
from airflow.decorators import dag, task | ||
from airflow.utils.dates import days_ago | ||
|
||
sys.path.append('.') | ||
from dags.config import (DEFAULT_ARGS, START_DATE, CONCURRENCY, SCHEDULE_INTERVAL) | ||
|
||
|
||
@dag(default_args=DEFAULT_ARGS, | ||
start_date=START_DATE, | ||
schedule_interval=timedelta(minutes=2), | ||
concurrency=CONCURRENCY) | ||
def predict(): | ||
@task | ||
def prepare_features_with_io_task() -> str: | ||
pass | ||
|
||
@task | ||
def predict_with_io_task(feature_path: str) -> None: | ||
pass | ||
|
||
feature_path = prepare_features_with_io_task() | ||
predict_with_io_task(feature_path) | ||
|
||
predict_dag = predict() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from src.usecase.data_processing.prepare_features import prepare_features_with_io | ||
from src.usecase.train_model import (train_pipeline_with_io) | ||
from dags.config import ( | ||
DEFAULT_ARGS, | ||
START_DATE, | ||
CONCURRENCY, | ||
SCHEDULE_INTERVAL, | ||
MODELS_PARAM, | ||
MLRUNS_DIR, | ||
TEST_DATA, | ||
TRACKING_URI, | ||
TRAIN_DATA, | ||
FEATURE_TRAIN_PATH, | ||
FEATURE_TEST_PATH, | ||
COL_TO_DROP) | ||
import sys | ||
|
||
from airflow.decorators import dag, task | ||
|
||
@dag(default_args=DEFAULT_ARGS, | ||
start_date=START_DATE, | ||
schedule_interval=SCHEDULE_INTERVAL, | ||
catchup=False, | ||
concurrency=CONCURRENCY) | ||
def train_pipeline(): | ||
|
||
@task | ||
def prepare_features_task( | ||
dataset_path: str, | ||
col_to_drop: list, | ||
feature_path: str) -> str: | ||
|
||
prepare_features_with_io( | ||
dataset_path=dataset_path, | ||
col_to_drop=col_to_drop, | ||
features_path=feature_path) | ||
|
||
return feature_path | ||
|
||
@task | ||
def train_model_task( | ||
feature_tain_path: str, | ||
feature_test_path: str, | ||
tracking_uri: str = TRACKING_URI, | ||
model_param: dict = MODELS_PARAM['xgboost'], | ||
mlruns_dir: str = MLRUNS_DIR) -> None: | ||
|
||
train_pipeline_with_io(feature_tain_path, feature_test_path, | ||
tracking_uri=tracking_uri, model_param=model_param, mlruns_dir=mlruns_dir) | ||
|
||
# Orchestration | ||
features_train_path = FEATURE_TRAIN_PATH | ||
features_test_path = FEATURE_TEST_PATH | ||
|
||
ml_train_path = prepare_features_task( | ||
dataset_path=TRAIN_DATA, | ||
col_to_drop=COL_TO_DROP, | ||
feature_path=features_train_path) | ||
|
||
ml_test_path = prepare_features_task( | ||
dataset_path=TEST_DATA, | ||
col_to_drop=COL_TO_DROP, | ||
feature_path=features_test_path) | ||
|
||
train_model_task(feature_tain_path=ml_train_path, feature_test_path=ml_test_path, tracking_uri=TRACKING_URI, | ||
model_param=MODELS_PARAM['xgboost'], mlruns_dir=MLRUNS_DIR) | ||
|
||
|
||
train_pipeline_dag = train_pipeline() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../DATA/DetecTeppe-2022-04-08/ml_dataset_2022_04_08 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
FROM python:3.7-slim-buster | ||
|
||
RUN pip install mlflow==1.19.0 psycopg2-binary==2.9.1 | ||
RUN pip install mlflow==1.28 psycopg2-binary==2.9.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.