Init MLflow support
mwiewior committed Apr 25, 2023
1 parent 9a72181 commit 16c3c80
Showing 2 changed files with 84 additions and 3 deletions.
25 changes: 25 additions & 0 deletions kedro_snowflake/config.py
@@ -43,6 +43,7 @@ def check_credentials(cls, values):
class DependenciesConfig(BaseModel):
packages: List[str] = [
"snowflake-snowpark-python",
"mlflow",
"cachetools",
"pluggy",
"PyYAML==6.0",
@@ -80,9 +81,22 @@ class SnowflakeRuntimeConfig(BaseModel):
pipeline_name_mapping: Optional[Dict[str, str]] = {"__default__": "default"}


class MLflowFunctionsConfig(BaseModel):
experiment_get_by_name: str = "mlflow_experiment_get_by_name"
run_create: str = "mlflow_run_create"
run_log_metric: str = "mlflow_run_log_metric"
run_log_parameter: str = "mlflow_run_log_parameter"


class SnowflakeMLflowConfig(BaseModel):
experiment_name: Optional[str]
functions: MLflowFunctionsConfig


class SnowflakeConfig(BaseModel):
connection: SnowflakeConnectionConfig
runtime: SnowflakeRuntimeConfig
mlflow: SnowflakeMLflowConfig


class KedroSnowflakeConfig(BaseModel):
@@ -136,6 +150,7 @@ class KedroSnowflakeConfig(BaseModel):
# https://repo.anaconda.com/pkgs/snowflake/
packages:
- snowflake-snowpark-python
- mlflow
- cachetools
- pluggy
- PyYAML==6.0
@@ -155,6 +170,16 @@
# Optionally provide mapping for user-friendly pipeline names
pipeline_name_mapping:
__default__: default
# EXPERIMENTAL: set an MLflow experiment name to enable MLflow tracking,
# or leave it empty (~) to disable tracking
mlflow:
experiment_name: ~
# Snowflake external functions used to call the MLflow instance
functions:
experiment_get_by_name: mlflow_experiment_get_by_name
run_create: mlflow_run_create
run_log_metric: mlflow_run_log_metric
run_log_parameter: mlflow_run_log_parameter
""".strip()

# This auto-validates the template above during import
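A minimal usage sketch of the new config section (not part of the commit), assuming the Pydantic models added above; the default function names mirror the YAML template:

from kedro_snowflake.config import MLflowFunctionsConfig, SnowflakeMLflowConfig

# Leaving experiment_name unset (None, i.e. `~` in YAML) keeps MLflow tracking off.
cfg = SnowflakeMLflowConfig(
    experiment_name="demo",  # hypothetical experiment name
    functions=MLflowFunctionsConfig(),  # defaults as in the template above
)
assert cfg.functions.run_create == "mlflow_run_create"
assert bool(cfg.experiment_name)  # the generator derives mlflow_enabled the same way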
62 changes: 59 additions & 3 deletions kedro_snowflake/generator.py
@@ -58,6 +58,7 @@ def __init__(
self.config = config
self.pipeline_name = pipeline_name
self.extra_env = extra_env
self.mlflow_enabled = bool(self.config.snowflake.mlflow.experiment_name)

def _get_pipeline_name_for_snowflake(self):
return (self.config.snowflake.runtime.pipeline_name_mapping or {}).get(
@@ -77,7 +78,7 @@ def _generate_task_sql(
warehouse=self.connection_parameters["warehouse"],
after_tasks=",".join(after_tasks),
task_body=self.TASK_BODY_TEMPLATE.format(
root_task_name=self._mlflow_root_task_name if self.mlflow_enabled else self._root_task_name,
environment=self.kedro_environment,
sproc_name=self.SPROC_NAME,
pipeline_name=pipeline_name,
@@ -100,6 +101,20 @@ def _generate_root_task_sql(self):
schedule=self.config.snowflake.runtime.schedule,
)

def _generate_mlflow_root_task_sql(self):
    return """
create or replace task {task_name}
warehouse = '{warehouse}'
after {after_task}
as
call {root_sproc}();
""".strip().format(
        task_name=self._mlflow_root_task_name,
        warehouse=self.connection_parameters["warehouse"],
        root_sproc=self._mlflow_root_sproc_name,
        after_task=self._root_task_name,
    )
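# Illustration (hypothetical pipeline mapped to "default"; see the name
# properties further below): the rendered statement comes out roughly as
#   create or replace task KEDRO_SNOWFLAKE_MLFLOW_START_DEFAULT_TASK
#   warehouse = '<warehouse>'
#   after KEDRO_SNOWFLAKE_START_DEFAULT_TASK
#   as
#   call KEDRO_SNOWFLAKE_START_MLFLOW_DEFAULT();
# i.e. the MLflow run is created right after the root task and before any node task.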

def _sanitize_node_name(self, node_name: str) -> str:
return re.sub(r"\W", "_", node_name)

@@ -108,12 +123,14 @@ def _generate_snowflake_tasks_sql(
pipeline: Pipeline,
) -> List[str]:
sql_statements = [self._generate_root_task_sql()]
if self.mlflow_enabled:
    sql_statements.append(self._generate_mlflow_root_task_sql())

node_dependencies = (
    pipeline.node_dependencies
)  # <-- this one is not topological
for node in pipeline.nodes:  # <-- this one is topological
    # Every node task waits on the root task (the MLflow root task when
    # tracking is enabled) plus the tasks of its upstream nodes.
    root_task = (
        self._mlflow_root_task_name if self.mlflow_enabled else self._root_task_name
    )
    after_tasks = [root_task] + [
        self._sanitize_node_name(n.name) for n in node_dependencies[node]
    ]
    sql_statements.append(
@@ -139,12 +156,23 @@ def _root_task_name(self):
root_task_name = f"kedro_snowflake_start_{self._get_pipeline_name_for_snowflake()}_task".upper()
return root_task_name

@property
def _mlflow_root_task_name(self):
mlflow_root_task_name = f"kedro_snowflake_mlflow_start_{self._get_pipeline_name_for_snowflake()}_task".upper()
return mlflow_root_task_name

@property
def _root_sproc_name(self):
return (
f"kedro_snowflake_start_{self._get_pipeline_name_for_snowflake()}".upper()
)

@property
def _mlflow_root_sproc_name(self):
return (
f"kedro_snowflake_start_mlflow_{self._get_pipeline_name_for_snowflake()}".upper()
)
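# For reference, with the default pipeline name mapping ({"__default__": "default"})
# these properties resolve to:
#   _root_task_name         -> KEDRO_SNOWFLAKE_START_DEFAULT_TASK
#   _mlflow_root_task_name  -> KEDRO_SNOWFLAKE_MLFLOW_START_DEFAULT_TASK
#   _root_sproc_name        -> KEDRO_SNOWFLAKE_START_DEFAULT
#   _mlflow_root_sproc_name -> KEDRO_SNOWFLAKE_START_MLFLOW_DEFAULT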

def generate(self) -> KedroSnowflakePipeline:
"""Generate a SnowflakePipeline object from a Kedro pipeline.
It can be used to run the pipeline or just to get the SQL statements.
@@ -201,6 +229,11 @@ def generate(self) -> KedroSnowflakePipeline:
snowflake_stage_name
)

if self.mlflow_enabled:
    mlflow_root_sproc = self._construct_kedro_snowflake_mlflow_root_sproc(
        snowflake_stage_name
    )

logger.info("Creating Kedro Snowflake Sproc")
snowflake_sproc = self._construct_kedro_snowflake_sproc(
imports=self._generate_imports_for_sproc(
@@ -269,10 +302,33 @@ def _drop_and_recreate_stages(self, *stages):
def snowflake_session(self):
return Session.builder.configs(self.connection_parameters).create()


def _construct_kedro_snowflake_mlflow_root_sproc(self, stage_location: str):
    experiment_name = self.config.snowflake.mlflow.experiment_name
    mlflow_functions = self.config.snowflake.mlflow.functions
    experiment_get_by_name_func = mlflow_functions.experiment_get_by_name
    run_create_func = mlflow_functions.run_create
    # The external function returns a VARIANT rendered as JSON text, so the
    # JSON-encoded experiment id is unquoted with eval() before re-use in SQL.
    experiment_id = eval(
        self.snowflake_session.sql(
            f"SELECT {experiment_get_by_name_func}('{experiment_name}')"
            ":body.experiments[0].experiment_id"
        ).collect()[0][0]
    )

    def mlflow_start_run(session: Session) -> str:
        run_id = eval(
            session.sql(
                f"SELECT {run_create_func}({experiment_id}):body.run.info.run_id"
            ).collect()[0][0]
        )
        # Stash the run id so downstream tasks in the task graph can read it.
        session.sql(f"call system$set_return_value('{run_id}');").collect()
        return run_id

    return sproc(
        func=mlflow_start_run,
        name=self._mlflow_root_sproc_name,
        is_permanent=True,
        replace=True,
        stage_location=stage_location,
        packages=["snowflake-snowpark-python"],
        execute_as="caller",
        session=self.snowflake_session,
    )
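# A hedged sketch (not part of this commit): a downstream task could read the run
# id that mlflow_start_run stashed via system$set_return_value by querying
# Snowflake's SYSTEM$GET_PREDECESSOR_RETURN_VALUE from a child task in the graph.
def read_mlflow_run_id(session: Session) -> str:
    # Returns the return value set by the predecessor task in the same task graph.
    return session.sql(
        "SELECT SYSTEM$GET_PREDECESSOR_RETURN_VALUE()"
    ).collect()[0][0]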

def _construct_kedro_snowflake_root_sproc(self, stage_location: str):
    def kedro_start_run(session: Session) -> str:
        from uuid import uuid4

        run_id = uuid4().hex
        session.sql(f"call system$set_return_value('{run_id}');").collect()
        return run_id