Skip to content

Commit

Permalink
Adding mlflow_helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiewior committed Jun 7, 2023
1 parent c880eb5 commit 83dcd45
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json

import mlflow
import pandas as pd
import snowflake.snowpark as sp
from kedro_snowflake.config import SnowflakeMLflowConfig
from sklearn.ensemble import RandomForestRegressor


def _get_current_session():
return sp.context.get_active_session()


def _get_mlflow_config():
session = _get_current_session()
json_obj = json.loads(session.sql(
"call system$get_predecessor_return_value('KEDRO_SNOWFLAKE_MLFLOW_START_DEFAULT_TASK')").collect()[
0][0])
mlflow_config = SnowflakeMLflowConfig.parse_obj(json_obj)
return mlflow_config


def _get_run_id():
return _get_mlflow_config().run_id


def _get_experiment_id():
session = _get_current_session()
mlflow_config = _get_mlflow_config()
return eval(session.sql(
f"SELECT {mlflow_config.functions.experiment_get_by_name}('{mlflow_config.experiment_name}'):body.experiments[0].experiment_id").collect()[
0][0])


def get_snowflake_account():
session = _get_current_session()
return eval(session.get_current_account())


def get_current_warehouse():
session = _get_current_session()
return eval(session.get_current_warehouse())


def log_metric(key: str, val: float):
session = _get_current_session()
mlflow_config = _get_mlflow_config()
run_id = _get_run_id()
session.sql(f"select {mlflow_config.functions.run_log_metric}('{run_id}', '{key}', {val})").collect()


def log_parameter(key: str, val: float):
session = _get_current_session()
mlflow_config = _get_mlflow_config()
run_id = _get_run_id()
session.sql(f"select {mlflow_config.functions.run_log_parameter}('{run_id}', '{key}', '{val}')").collect()


def log_model(model: RandomForestRegressor, X_train: pd.DataFrame, y_train: pd.Series):
mlflow.set_tracking_uri("file:///tmp/mlruns")
mlflow_config = _get_mlflow_config()
local_exp_id = mlflow.create_experiment("temp")
exp_id = _get_experiment_id()
run_id = _get_run_id()
with mlflow.start_run(run_name='local', experiment_id=local_exp_id) as run:
local_run_id = run.info.run_id
mlflow.sklearn.autolog()
model.fit(X_train, y_train)
session = _get_current_session()
session.file.put(f"/tmp/mlruns/{local_exp_id}/{local_run_id}/artifacts/model/*",
f"@{mlflow_config.stage_name}/{exp_id}/{run_id}/artifacts/model/",
auto_compress=False)
return model
1 change: 1 addition & 0 deletions tests/conf/local/snowflake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ snowflake:
__default__: default
mlflow:
experiment_name: ~
stage_name: ~
functions:
experiment_get_by_name: mlflow_experiment_get_by_name
run_create: mlflow_run_create
Expand Down

0 comments on commit 83dcd45

Please sign in to comment.