2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ds-platform-utils"
version = "0.1.0"
version = "0.2.0"
description = "Utility library for Pattern Data Science."
readme = "README.md"
authors = [
16 changes: 14 additions & 2 deletions src/ds_platform_utils/metaflow/pandas.py
@@ -1,3 +1,4 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
@@ -17,7 +18,7 @@
# )
from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA
from ds_platform_utils.metaflow.get_snowflake_connection import _debug_print_query, get_snowflake_connection
from ds_platform_utils.metaflow.write_audit_publish import _make_snowflake_table_url
from ds_platform_utils.metaflow.write_audit_publish import _make_snowflake_table_url, get_select_dev_query_tags

TWarehouse = Literal[
"OUTERBOUNDS_DATA_SCIENCE_XS_WH",
@@ -96,12 +97,18 @@ def publish_pandas( # noqa: PLR0913 (too many arguments)
current.card.append(Table.from_dataframe(df.head()))

conn: SnowflakeConnection = get_snowflake_connection(use_utc)

# set warehouse
if warehouse is not None:
with conn.cursor() as cur:
cur.execute(f"USE WAREHOUSE {warehouse};")

# set query tag for cost tracking in select.dev
# REASON: write_pandas() doesn't let us add SQL comments to the query it issues,
# so we set a session-level query tag instead.
tags = get_select_dev_query_tags()
query_tag_str = json.dumps(tags)
cur.execute(f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';")

# https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/snowpark/api/snowflake.snowpark.Session.write_pandas
write_pandas(
conn=conn,
@@ -154,7 +161,12 @@ def query_pandas_from_snowflake(
substitute_map_into_string,
)

# append the query tags as a SQL comment for cost tracking in select.dev
tags = get_select_dev_query_tags()
query_comment_str = f"\n\n/* {json.dumps(tags)} */"
query = get_query_from_string_or_fpath(query)
query = query + query_comment_str

if "{{schema}}" in query or "{{ schema }}" in query:
schema = PROD_SCHEMA if current.is_production else NON_PROD_SCHEMA
query = substitute_map_into_string(query, {"schema": schema})
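For reference, a minimal sketch of the session-tag workaround used in `publish_pandas` above. The helper name `set_select_dev_session_tag` is invented for illustration; the point is that `write_pandas()` offers no hook for injecting a SQL comment, so the tags are attached to the session instead:

```python
import json
from typing import Dict

from snowflake.connector import SnowflakeConnection


def set_select_dev_session_tag(conn: SnowflakeConnection, tags: Dict[str, str]) -> None:
    """Attach select.dev tags to the session (illustrative helper, not part of the library)."""
    with conn.cursor() as cur:
        # A session-level QUERY_TAG attributes every statement issued in this
        # session, including the INSERTs generated by write_pandas(), to these tags.
        cur.execute(f"ALTER SESSION SET QUERY_TAG = '{json.dumps(tags)}';")
```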
98 changes: 97 additions & 1 deletion src/ds_platform_utils/metaflow/write_audit_publish.py
@@ -1,3 +1,5 @@
import json
import warnings
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -24,6 +26,91 @@
]


def get_select_dev_query_tags() -> Dict[str, str]:
"""Return tags for the current Metaflow flow run.

These tags are used for cost tracking in select.dev.
See the select.dev docs on custom workload tags:
https://select.dev/docs/reference/integrations/custom-workloads#example-query-tag

What the main tags mean and why we set them this way:

"app": a broad category that groups queries by domain. We set app to the value of ds.domain
that we get from current tags of the flow, so queries are attributed to the right domain (for example, "Operations").

"workload_id": identifies the specific project or sub-unit inside that domain.
We set workload_id to the value of ds.project that we get from current tags of
the flow so select.dev can attribute costs to the exact project (for example, "out-of-stock").

For more granular attribution we have other tags:

"pipeline": the flow name

"step_name": the step within the flow

"run_id": the unique id of the flow run

"user": the username of the user who triggered the flow run (or argo-work

"namespace": the namespace of the flow run

"team": the team name, hardcoded as "data-science" for all flows

**Note: all other tags are arbitrary. Add any extra key/value pairs that help you trace and group queries for cost reporting.**
"""
fetched_tags = current.tags
# Both required Metaflow user tags (ds.domain and ds.project) must be present on the flow.
required_tags_are_present = any(tag.startswith("ds.project") for tag in fetched_tags) and any(
tag.startswith("ds.domain") for tag in fetched_tags
)
if not required_tags_are_present:
warnings.warn(
dedent("""
Warning: ds-platform-utils attempted to add query tags to a Snowflake query
for cost tracking in select.dev, but no tags were found on this Metaflow flow.
Please add them with --tag, for example:

uv run <flow_name>_flow.py \\
--environment=fast-bakery \\
--package-suffixes='.csv,.sql,.json,.toml,.yaml,.yml,.txt' \\
--with card \\
argo-workflows create \\
--tag "ds.domain:operations" \\
--tag "ds.project:regional-forecast"

Note: in the monorepo, these tags are applied automatically in CI and when using
the standard poe commands for running flows.
"""),
stacklevel=2,
)

def extract(prefix: str, default: str = "unknown") -> str:
for tag in fetched_tags:
if tag.startswith(prefix + ":"):
return tag.split(":", 1)[1]
return default

# most of these fall back to "unknown" if no tags are set on the flow
# (most likely for flow runs triggered manually from a local machine)
return {
"app": extract("ds.domain"),  # domain of the flow, taken from the ds.domain user tag
"workload_id": extract("ds.project"),  # project the flow belongs to, taken from the ds.project user tag
"flow_name": current.flow_name,  # name of the Metaflow flow
"project": current.project_name,  # project name from the @project decorator; lets us
# identify the flow's project without relying on user tags (added via --tag)
"step_name": current.step_name,  # name of the current step
"run_id": current.run_id,  # unique id of the current run
"user": current.username,  # user who triggered the run ("argo-workflows" if it's a deployed flow)
"domain": extract("ds.domain"),  # business unit (domain) of the flow, same as "app"
"namespace": current.namespace,  # namespace of the flow run
"perimeter": "PROD" if current.is_production else "Default",  # perimeter of the flow
"team": "data-science",  # team name, hardcoded for all flows
}
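As a standalone sketch of how the nested `extract` helper above resolves tags (tag values are made up):

```python
# Standalone sketch of the prefix extraction above; tag values are illustrative.
fetched_tags = {"ds.domain:operations", "ds.project:out-of-stock"}


def extract(prefix: str, default: str = "unknown") -> str:
    for tag in fetched_tags:
        if tag.startswith(prefix + ":"):
            return tag.split(":", 1)[1]  # keep everything after the first colon
    return default


assert extract("ds.domain") == "operations"
assert extract("ds.project") == "out-of-stock"
assert extract("ds.team") == "unknown"  # a missing tag falls back to the default
```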


def publish( # noqa: PLR0913
table_name: str,
query: Union[str, Path],
@@ -33,10 +120,19 @@ def publish( # noqa: PLR0913
use_utc: bool = True,
) -> None:
"""Publish a table using write-audit-publish pattern with Metaflow's Snowflake connection."""
from ds_platform_utils._snowflake.write_audit_publish import write_audit_publish
from ds_platform_utils._snowflake.write_audit_publish import (
get_query_from_string_or_fpath,
write_audit_publish,
)

conn = get_snowflake_connection(use_utc=use_utc)

# append the query tags as a SQL comment for cost tracking in select.dev
tags = get_select_dev_query_tags()
query_comment_str = f"\n\n/* {json.dumps(tags)} */"
query = get_query_from_string_or_fpath(query)
query = query + query_comment_str

with conn.cursor() as cur:
if warehouse is not None:
cur.execute(f"USE WAREHOUSE {warehouse}")
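A minimal sketch of the comment-append route shared by `query_pandas_from_snowflake` and `publish` (tags and table name are illustrative):

```python
import json

# Illustrative tags; at runtime these come from get_select_dev_query_tags().
tags = {"app": "operations", "workload_id": "out-of-stock", "team": "data-science"}

query = "SELECT * FROM {{ schema }}.forecast_inputs"
tagged_query = query + f"\n\n/* {json.dumps(tags)} */"
# The JSON rides along as a trailing block comment in the SQL text, which
# select.dev can parse back out of Snowflake's query history.
print(tagged_query)
```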
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.