Skip to content

Commit

Permalink
feat(ingest/snowflake): support email_as_user_identifier for queries …
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate authored Dec 27, 2024
1 parent 172736a commit 3ca8d09
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,20 @@ class SnowflakeIdentifierConfig(
description="Whether to convert dataset urns to lowercase.",
)


class SnowflakeUsageConfig(BaseUsageConfig):
email_domain: Optional[str] = pydantic.Field(
default=None,
description="Email domain of your organization so users can be displayed on UI appropriately.",
)

email_as_user_identifier: bool = Field(
default=True,
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
"provided, generates email addresses for snowflake users with unset emails, based on their "
"username.",
)


class SnowflakeUsageConfig(BaseUsageConfig):
apply_view_usage_to_tables: bool = pydantic.Field(
default=False,
description="Whether to apply view's usage to its base tables. If set to True, usage is applied to base tables only.",
Expand Down Expand Up @@ -267,13 +275,6 @@ class SnowflakeV2Config(
" Map of share name -> details of share.",
)

email_as_user_identifier: bool = Field(
default=True,
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
"provided, generates email addresses for snowflake users with unset emails, based on their "
"username.",
)

include_assertion_results: bool = Field(
default=False,
description="Whether to ingest assertion run results for assertions created using Datahub"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@

logger = logging.getLogger(__name__)

# Define a type alias
UserName = str
UserEmail = str
UsersMapping = Dict[UserName, UserEmail]


class SnowflakeQueriesExtractorConfig(ConfigModel):
# TODO: Support stateful ingestion for the time windows.
Expand Down Expand Up @@ -114,11 +119,13 @@ class SnowflakeQueriesSourceConfig(
class SnowflakeQueriesExtractorReport(Report):
copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
users_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)

audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
sql_aggregator: Optional[SqlAggregatorReport] = None

num_ddl_queries_dropped: int = 0
num_users: int = 0


@dataclass
Expand Down Expand Up @@ -225,6 +232,9 @@ def is_allowed_table(self, name: str) -> bool:
def get_workunits_internal(
self,
) -> Iterable[MetadataWorkUnit]:
with self.report.users_fetch_timer:
users = self.fetch_users()

# TODO: Add some logic to check if the cached audit log is stale or not.
audit_log_file = self.local_temp_path / "audit_log.sqlite"
use_cached_audit_log = audit_log_file.exists()
Expand All @@ -248,7 +258,7 @@ def get_workunits_internal(
queries.append(entry)

with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log():
for entry in self.fetch_query_log(users):
queries.append(entry)

with self.report.audit_log_load_timer:
Expand All @@ -263,6 +273,25 @@ def get_workunits_internal(
shared_connection.close()
audit_log_file.unlink(missing_ok=True)

def fetch_users(self) -> UsersMapping:
users: UsersMapping = dict()
with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
logger.info("Fetching users from Snowflake")
query = SnowflakeQuery.get_all_users()
resp = self.connection.query(query)

for row in resp:
try:
users[row["NAME"]] = row["EMAIL"]
self.report.num_users += 1
except Exception as e:
self.structured_reporter.warning(
"Error parsing user row",
context=f"{row}",
exc=e,
)
return users

def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
# Derived from _populate_external_lineage_from_copy_history.

Expand Down Expand Up @@ -298,7 +327,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
yield result

def fetch_query_log(
self,
self, users: UsersMapping
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]:
query_log_query = _build_enriched_query_log_query(
start_time=self.config.window.start_time,
Expand All @@ -319,7 +348,7 @@ def fetch_query_log(

assert isinstance(row, dict)
try:
entry = self._parse_audit_log_row(row)
entry = self._parse_audit_log_row(row, users)
except Exception as e:
self.structured_reporter.warning(
"Error parsing query log row",
Expand All @@ -331,7 +360,7 @@ def fetch_query_log(
yield entry

def _parse_audit_log_row(
self, row: Dict[str, Any]
self, row: Dict[str, Any], users: UsersMapping
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]:
json_fields = {
"DIRECT_OBJECTS_ACCESSED",
Expand Down Expand Up @@ -430,9 +459,11 @@ def _parse_audit_log_row(
)
)

# TODO: Fetch email addresses from Snowflake to map user -> email
# TODO: Support email_domain fallback for generating user urns.
user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"]))
user = CorpUserUrn(
self.identifiers.get_user_identifier(
res["user_name"], users.get(res["user_name"])
)
)

timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -947,4 +947,8 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str:
AND METRIC_NAME ilike '{pattern}' escape '{escape_pattern}'
ORDER BY MEASUREMENT_TIME ASC;
"""
"""

@staticmethod
def get_all_users() -> str:
return """SELECT name as "NAME", email as "EMAIL" FROM SNOWFLAKE.ACCOUNT_USAGE.USERS"""
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,9 @@ def _map_user_counts(
filtered_user_counts.append(
DatasetUserUsageCounts(
user=make_user_urn(
self.get_user_identifier(
self.identifiers.get_user_identifier(
user_count["user_name"],
user_email,
self.config.email_as_user_identifier,
)
),
count=user_count["total"],
Expand Down Expand Up @@ -453,9 +452,7 @@ def _get_operation_aspect_work_unit(
reported_time: int = int(time.time() * 1000)
last_updated_timestamp: int = int(start_time.timestamp() * 1000)
user_urn = make_user_urn(
self.get_user_identifier(
user_name, user_email, self.config.email_as_user_identifier
)
self.identifiers.get_user_identifier(user_name, user_email)
)

# NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ def get_quoted_identifier_for_schema(db_name, schema_name):
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
return f'"{db_name}"."{schema_name}"."{table_name}"'

# Note - decide how to construct user urns.
# Historically urns were created using part before @ from user's email.
# Users without email were skipped from both user entries as well as aggregates.
# However email is not mandatory field in snowflake user, user_name is always present.
def get_user_identifier(
self,
user_name: str,
user_email: Optional[str],
) -> str:
if user_email:
return self.snowflake_identifier(
user_email
if self.identifier_config.email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.snowflake_identifier(
f"{user_name}@{self.identifier_config.email_domain}"
if self.identifier_config.email_as_user_identifier is True
and self.identifier_config.email_domain is not None
else user_name
)


class SnowflakeCommonMixin(SnowflakeStructuredReportMixin):
platform = "snowflake"
Expand All @@ -315,24 +337,6 @@ def structured_reporter(self) -> SourceReport:
def identifiers(self) -> SnowflakeIdentifierBuilder:
return SnowflakeIdentifierBuilder(self.config, self.report)

# Note - decide how to construct user urns.
# Historically urns were created using part before @ from user's email.
# Users without email were skipped from both user entries as well as aggregates.
# However email is not mandatory field in snowflake user, user_name is always present.
def get_user_identifier(
self,
user_name: str,
user_email: Optional[str],
email_as_user_identifier: bool,
) -> str:
if user_email:
return self.identifiers.snowflake_identifier(
user_email
if email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.identifiers.snowflake_identifier(user_name)

# TODO: Revisit this after stateful ingestion can commit checkpoint
# for failures that do not affect the checkpoint
# TODO: Add additional parameters to match the signature of the .warning and .failure methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,58 @@ def test_source_close_cleans_tmp(snowflake_connect, tmp_path):
# This closes QueriesExtractor which in turn closes SqlParsingAggregator
source.close()
assert len(os.listdir(tmp_path)) == 0


@patch("snowflake.connector.connect")
def test_user_identifiers_email_as_identifier(snowflake_connect, tmp_path):
source = SnowflakeQueriesSource.create(
{
"connection": {
"account_id": "ABC12345.ap-south-1.aws",
"username": "TST_USR",
"password": "TST_PWD",
},
"email_as_user_identifier": True,
"email_domain": "example.com",
},
PipelineContext("run-id"),
)
assert (
source.identifiers.get_user_identifier("username", "[email protected]")
== "[email protected]"
)
assert (
source.identifiers.get_user_identifier("username", None)
== "[email protected]"
)

# We'd do best effort to use email as identifier, but would keep username as is,
# if email can't be formed.
source.identifiers.identifier_config.email_domain = None

assert (
source.identifiers.get_user_identifier("username", "[email protected]")
== "[email protected]"
)

assert source.identifiers.get_user_identifier("username", None) == "username"


@patch("snowflake.connector.connect")
def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path):
source = SnowflakeQueriesSource.create(
{
"connection": {
"account_id": "ABC12345.ap-south-1.aws",
"username": "TST_USR",
"password": "TST_PWD",
},
"email_as_user_identifier": False,
},
PipelineContext("run-id"),
)
assert (
source.identifiers.get_user_identifier("username", "[email protected]")
== "username"
)
assert source.identifiers.get_user_identifier("username", None) == "username"

0 comments on commit 3ca8d09

Please sign in to comment.