Skip to content

Commit

Permalink
Adding pipeline name to each task/node
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiewior committed Jun 8, 2023
1 parent 115a60d commit ea85f61
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions kedro_snowflake/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def _generate_mlflow_root_task_sql(self):
after_task=self._root_task_name,
)

def _sanitize_node_name(self, node_name: str) -> str:
return re.sub(r"\W", "_", node_name)
def _standardize_node_name(self, node_name: str) -> str:
sanity_node_name = re.sub(r"\W", "_", node_name)
return f"kedro_{self._get_pipeline_name_for_snowflake()}_{sanity_node_name}"

def _generate_snowflake_tasks_sql(
self,
Expand All @@ -156,13 +157,14 @@ def _generate_snowflake_tasks_sql(
) # <-- this one is not topological
for node in pipeline.nodes: # <-- this one is topological
after_tasks = [self._root_task_name] + [
self._sanitize_node_name(n.name) for n in node_dependencies[node]
f"{self._standardize_node_name(n.name)}"
for n in node_dependencies[node]
]
if self.mlflow_enabled:
after_tasks.append(self._mlflow_root_task_name)
sql_statements.append(
self._generate_task_sql(
self._sanitize_node_name(node.name),
self._standardize_node_name(node.name),
after_tasks,
self.pipeline_name,
[node.name],
Expand All @@ -181,23 +183,25 @@ def _generate_task_execute_sql(self):

@property
def _root_task_name(self):
root_task_name = f"kedro_snowflake_start_{self._get_pipeline_name_for_snowflake()}_task".upper()
root_task_name = (
f"kedro_{self._get_pipeline_name_for_snowflake()}__start_task".upper()
)
return root_task_name

@property
def _mlflow_root_task_name(self):
mlflow_root_task_name = f"kedro_snowflake_mlflow_start_{self._get_pipeline_name_for_snowflake()}_task".upper()
mlflow_root_task_name = (
f"kedro_{self._get_pipeline_name_for_snowflake()}_mlflow_start_task".upper()
)
return mlflow_root_task_name

@property
def _root_sproc_name(self):
return (
f"kedro_snowflake_start_{self._get_pipeline_name_for_snowflake()}".upper()
)
return f"kedro_{self._get_pipeline_name_for_snowflake()}_start".upper()

@property
def _mlfow_root_sproc_name(self):
return f"kedro_snowflake_start_mlflow_{self._get_pipeline_name_for_snowflake()}".upper()
return f"kedro_{self._get_pipeline_name_for_snowflake()}_start_mlflow".upper()

def generate(self) -> KedroSnowflakePipeline:
"""Generate a SnowflakePipeline object from a Kedro pipeline.
Expand Down Expand Up @@ -279,7 +283,7 @@ def generate(self) -> KedroSnowflakePipeline:
pipeline_sql_statements,
self._generate_task_execute_sql(),
self._root_task_name,
[self._sanitize_node_name(n.name) for n in pipeline.nodes],
[self._standardize_node_name(n.name) for n in pipeline.nodes],
)

def _generate_imports_for_sproc(self, dependencies_dir, snowflake_stage_name):
Expand Down

0 comments on commit ea85f61

Please sign in to comment.