refactor: other refactoring
shubhammehra4 committed Nov 14, 2024
1 parent cf7caeb commit 060962b
Showing 3 changed files with 34 additions and 28 deletions.
src/predictions/profiles_mlcorelib/__init__.py (2 additions & 2 deletions)
@@ -25,6 +25,6 @@ def register_extensions(project):
 
     project.register_model_type(AuditIdStitcherModel)
 
-    from .py_native.profiles_tutorial.model import TutorialModel
+    # from .py_native.profiles_tutorial.model import TutorialModel
 
-    project.register_model_type(TutorialModel)
+    # project.register_model_type(TutorialModel)
(second changed file; filename not captured in this view)
@@ -21,7 +21,7 @@ def __init__(
         self.client = client
         self.schema = client.schema
         self.db = client.db
-        self.io_handler = io_handler
+        self.io = io_handler
         self.fast_mode = fast_mode
 
     def get_qualified_name(self, table: str) -> str:
@@ -33,6 +33,7 @@ def upload_sample_data(
     ) -> Dict[str, str]:
         new_table_names = {}
         to_upload = True
+        # TODO: Need to update this
         res = self.client.query_sql_with_result(
             f"SHOW TABLES IN SCHEMA {self.db}.{self.schema}"
         )
@@ -50,12 +51,12 @@ def upload_sample_data(
                 continue
 
             if table_name.lower() in existing_tables:
-                print(f"Table {table_name} already exists.")
-                action = self.io_handler.get_user_input(
+                self.io.display_message(f"Table {table_name} already exists.")
+                action = self.io.get_user_input(
                     "Do you want to skip uploading again, so we can reuse the tables? (yes/no) (yes - skips upload, no - uploads again): "
                 )
                 if action == "yes":
-                    print("Skipping upload of all csv files.")
+                    self.io.display_message("Skipping upload of all csv files.")
                     to_upload = False
                     continue
 
@@ -64,7 +65,7 @@ def upload_sample_data(
             )
 
             df = pd.read_csv(os.path.join(sample_data_dir, filename))
-            print(
+            self.io.display_message(
                 f"Uploading file {filename} as table {table_name} with {df.shape[0]} rows and {df.shape[1]} columns"
             )
             self.client.write_df_to_table(
@@ -76,6 +77,7 @@ def upload_sample_data(
         return new_table_names
 
     def find_relevant_tables(self, new_table_names: Dict[str, str]) -> List[str]:
+        # TODO: Need to update this
         res = self.client.query_sql_with_result(
             f"SHOW TABLES IN SCHEMA {self.db}.{self.schema}"
         )
@@ -87,6 +89,7 @@ def find_relevant_tables(self, new_table_names: Dict[str, str]) -> List[str]:
 
     def get_columns(self, table: str) -> List[str]:
         try:
+            # TODO: Need to update this
             query = f"DESCRIBE TABLE {self.db}.{self.schema}.{table}"
             result = self.client.query_sql_with_result(query)
             columns = [row["name"] for _, row in result.iterrows()]
@@ -138,24 +141,24 @@ def map_columns_to_id_types(
                 shortlisted_columns[id_type] = matched_columns
 
         # Display table context
-        print(f"\n{'-'*80}\n")
-        print(
+        self.io.display_message(f"\n{'-'*80}\n")
+        self.io.display_message(
             f"The table `{table}` has the following columns, which look like id types:\n"
         )
 
         # Display shortlisted columns with sample data
         for id_type, cols in shortlisted_columns.items():
             for col in cols:
                 sample_data = self.get_sample_data(table, col)
-                print(f"id_type: {id_type}")
-                print(f"column: {col} (sample data: {sample_data})\n")
+                self.io.display_message(f"id_type: {id_type}")
+                self.io.display_message(f"column: {col} (sample data: {sample_data})\n")
 
         # Display all available id_types
-        print(
+        self.io.display_message(
             f"Following are all the id types defined earlier: \n\t{','.join(id_types)}"
         )
         shortlisted_id_types = ",".join(list(shortlisted_columns.keys()))
-        applicable_id_types_input = self.io_handler.get_user_input(
+        applicable_id_types_input = self.io.get_user_input(
             f"Enter the comma-separated list of id_types applicable to the `{table}` table: \n>",
             options=[shortlisted_id_types],
             default=shortlisted_id_types,
@@ -177,24 +180,24 @@
         # Assert that all in shortlisted columns are in applicable_id_types
         for id_type in shortlisted_columns:
             if id_type not in applicable_id_types:
-                print(
+                self.io.display_message(
                     f"Please enter all id types applicable to the `{table}` table. The id type `{id_type}` is not found."
                 )
                 return None, "back"
 
         if not applicable_id_types:
-            print(
+            self.io.display_message(
                 f"No valid id_types selected for `{table}` table. Skipping this table (it won't be part of id stitcher)"
             )
             return [], "next"
 
-        print(
+        self.io.display_message(
             f"\nNow let's map different id_types in table `{table}` to a column (you can also use SQL string operations on these columns: ex: LOWER(EMAIL_ID), in case you want to use email as an id_type while also treating them as case insensitive):\n"
         )
         table_mappings = []
         for id_type in applicable_id_types:
             while True:
-                print(f"\nid type: {id_type}")
+                self.io.display_message(f"\nid type: {id_type}")
                 # Suggest columns based on regex matches
                 # suggested_cols = shortlisted_columns.get(id_type, [])
                 # if suggested_cols:
@@ -206,26 +209,30 @@
                 default = id_type_mapping.get(id_type, id_type)
                 # else:
                 #     default = None
-                user_input = self.io_handler.get_user_input(
+                user_input = self.io.get_user_input(
                     f"Enter the column(s) to map the id_type '{id_type}' in table `{table}`, or 'skip' to skip:\n> ",
                     default=default,
                     options=[default],
                 )
                 if user_input.lower() == "back":
                     return None, "back"
                 if user_input.lower() == "skip":
-                    print(f"Skipping id_type '{id_type}' for table `{table}`")
+                    self.io.display_message(
+                        f"Skipping id_type '{id_type}' for table `{table}`"
+                    )
                     break
 
                 selected_columns = [col.strip() for col in user_input.split(",")]
                 if not selected_columns:
-                    print("No valid columns selected. Please try again.\n")
+                    self.io.display_message(
+                        "No valid columns selected. Please try again.\n"
+                    )
                     continue
                 # Display selected columns with sample data for confirmation
-                print(f"Selected columns for id_type '{id_type}':")
+                self.io.display_message(f"Selected columns for id_type '{id_type}':")
                 for col in selected_columns:
                     sample_data = self.get_sample_data(table, col)
-                    print(f"- {col} (sample data: {sample_data})")
+                    self.io.display_message(f"- {col} (sample data: {sample_data})")
 
                 # confirm = self.io_handler.get_user_input("Is this correct? (yes/no): ", options=["yes", "no"])
                 # if confirm.lower() == 'yes':
# else:
# logger.info("Let's try mapping again.\n")
if table_mappings:
print("Following is the summary of id types selected: \n")
self.io.display_message("Following is the summary of id types selected: \n")
summary = {"table": table, "ids": table_mappings}
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.preserve_quotes = True
yaml.width = 4096 # Prevent line wrapping
yaml.dump(summary, sys.stdout)
print("\n")
self.io_handler.get_user_input(
f"The above is the inputs yaml for table `{table}`"
)
self.io.display_message("\n")
self.io.get_user_input(f"The above is the inputs yaml for table `{table}`")
else:
self.io_handler.get_user_input(
self.io.get_user_input(
"No id_type mappings were selected for this table.\n"
)
return table_mappings, "next"
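
The thread running through this file's hunks is that every bare print() is rerouted through self.io, the renamed handler that already owned get_user_input. Below is a minimal sketch of the contract those calls imply, reconstructed only from what is visible in this diff; the class name ConsoleIOHandler and its concrete behavior are assumptions, not the repository's actual implementation.

# Minimal sketch of the io handler contract implied by the calls in this diff.
# Only the three method names and their call shapes (prompt, options=, default=)
# come from the changes above; everything else here is an assumption.
from typing import List, Optional


class ConsoleIOHandler:
    """Routes tutorial output and prompts through one object instead of bare print()/input()."""

    def display_message(self, message: str) -> None:
        # Drop-in replacement for the print() calls removed in this commit.
        print(message)

    def display_multiline_message(self, message: str) -> None:
        # Used for longer explanatory text (e.g. messages.EXPLAIN_BAD_ANNOYMOUS_IDS);
        # assumed here to differ from display_message only in formatting.
        print(message)

    def get_user_input(
        self,
        prompt: str,
        options: Optional[List[str]] = None,
        default: Optional[str] = None,
    ) -> str:
        # `options` is treated as a suggestion rather than strict validation,
        # since callers also accept free-form replies such as "back" or "skip".
        if options:
            prompt = f"{prompt}(suggested: {', '.join(options)}) "
        reply = input(prompt).strip()
        if not reply and default is not None:
            return default
        return reply

Funneling output through a single object like this keeps the tutorial flow easy to capture or mock in tests, which is presumably also why the attribute was shortened from io_handler to io alongside the print() cleanup.
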
(third changed file; filename not captured in this view)
@@ -441,6 +441,7 @@ def second_run(
         self.io.display_message(dense_edges.head(20).to_string())
         self.io.display_multiline_message(messages.EXPLAIN_BAD_ANNOYMOUS_IDS)
 
+        # TODO: fix this query
         query_investigate_bad_anons = f"""
         WITH edge_table as (
             SELECT