Skip to content

Commit

Permalink
Passing mlflow config
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiewior committed Jun 7, 2023
1 parent 5d65a8b commit 7435c6e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions kedro_snowflake/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class MLflowFunctionsConfig(BaseModel):
class SnowflakeMLflowConfig(BaseModel):
experiment_name: Optional[str]
functions: MLflowFunctionsConfig
run_id: Optional[str]


class SnowflakeConfig(BaseModel):
Expand Down
13 changes: 10 additions & 3 deletions kedro_snowflake/generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import re
Expand All @@ -7,15 +8,16 @@
from typing import Any, Dict, List, Optional

from kedro.pipeline import Pipeline
from snowflake.snowpark.functions import sproc
from snowflake.snowpark.session import Session

from kedro_snowflake.config import KedroSnowflakeConfig
from kedro_snowflake.pipeline import KedroSnowflakePipeline
from kedro_snowflake.utils import (
get_module_path,
zip_dependencies,
zstd_folder,
)
from snowflake.snowpark.functions import sproc
from snowflake.snowpark.session import Session

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -339,14 +341,19 @@ def _construct_kedro_snowflake_mlflow_root_sproc(self, stage_location: str):
f"SELECT {experiment_get_by_name_func}('{experiment_name}'):body.experiments[0].experiment_id"
).collect()[0][0]
)
mlflow_config = self.config.snowflake.mlflow.dict()

def mlflow_start_run(session: Session) -> str:
run_id = eval(
session.sql(
f"SELECT {run_create_func}({experiment_id}):body.run.info.run_id"
).collect()[0][0]
)
session.sql(f"call system$set_return_value('{run_id}');").collect()
mlflow_config["run_id"] = run_id
mlflow_config_json = json.dumps(mlflow_config)
session.sql(
f"call system$set_return_value('{mlflow_config_json}');"
).collect()
return run_id

return sproc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ snowflake:
- more-itertools
- openpyxl
- backoff
- pydantic
# Optionally provide mapping for user-friendly pipeline names
pipeline_name_mapping:
__default__: default
Expand Down

0 comments on commit 7435c6e

Please sign in to comment.