-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
74 additions
and
0 deletions.
There are no files selected for viewing
73 changes: 73 additions & 0 deletions
73
...okiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/mlflow_helpers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import json | ||
|
||
import mlflow | ||
import pandas as pd | ||
import snowflake.snowpark as sp | ||
from kedro_snowflake.config import SnowflakeMLflowConfig | ||
from sklearn.ensemble import RandomForestRegressor | ||
|
||
|
||
def _get_current_session(): | ||
return sp.context.get_active_session() | ||
|
||
|
||
def _get_mlflow_config(): | ||
session = _get_current_session() | ||
json_obj = json.loads(session.sql( | ||
"call system$get_predecessor_return_value('KEDRO_SNOWFLAKE_MLFLOW_START_DEFAULT_TASK')").collect()[ | ||
0][0]) | ||
mlflow_config = SnowflakeMLflowConfig.parse_obj(json_obj) | ||
return mlflow_config | ||
|
||
|
||
def _get_run_id(): | ||
return _get_mlflow_config().run_id | ||
|
||
|
||
def _get_experiment_id(): | ||
session = _get_current_session() | ||
mlflow_config = _get_mlflow_config() | ||
return eval(session.sql( | ||
f"SELECT {mlflow_config.functions.experiment_get_by_name}('{mlflow_config.experiment_name}'):body.experiments[0].experiment_id").collect()[ | ||
0][0]) | ||
|
||
|
||
def get_snowflake_account(): | ||
session = _get_current_session() | ||
return eval(session.get_current_account()) | ||
|
||
|
||
def get_current_warehouse(): | ||
session = _get_current_session() | ||
return eval(session.get_current_warehouse()) | ||
|
||
|
||
def log_metric(key: str, val: float): | ||
session = _get_current_session() | ||
mlflow_config = _get_mlflow_config() | ||
run_id = _get_run_id() | ||
session.sql(f"select {mlflow_config.functions.run_log_metric}('{run_id}', '{key}', {val})").collect() | ||
|
||
|
||
def log_parameter(key: str, val: float): | ||
session = _get_current_session() | ||
mlflow_config = _get_mlflow_config() | ||
run_id = _get_run_id() | ||
session.sql(f"select {mlflow_config.functions.run_log_parameter}('{run_id}', '{key}', '{val}')").collect() | ||
|
||
|
||
def log_model(model: RandomForestRegressor, X_train: pd.DataFrame, y_train: pd.Series): | ||
mlflow.set_tracking_uri("file:///tmp/mlruns") | ||
mlflow_config = _get_mlflow_config() | ||
local_exp_id = mlflow.create_experiment("temp") | ||
exp_id = _get_experiment_id() | ||
run_id = _get_run_id() | ||
with mlflow.start_run(run_name='local', experiment_id=local_exp_id) as run: | ||
local_run_id = run.info.run_id | ||
mlflow.sklearn.autolog() | ||
model.fit(X_train, y_train) | ||
session = _get_current_session() | ||
session.file.put(f"/tmp/mlruns/{local_exp_id}/{local_run_id}/artifacts/model/*", | ||
f"@{mlflow_config.stage_name}/{exp_id}/{run_id}/artifacts/model/", | ||
auto_compress=False) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters