Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cluster multiple add stats #2453

Merged
merged 2 commits into from
Oct 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 58 additions & 45 deletions splink/internals/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,22 +223,23 @@ def _calculate_stable_clusters_at_new_threshold(
return sqls


def _threshold_to_str(x: float) -> str:
    """Render a threshold as a string with ``.`` replaced by ``_``.

    The value is formatted to 8 decimal places, trailing zeros are stripped
    and the decimal point becomes an underscore, e.g. ``0.95 -> "0_95"``.
    Exactly ``0.0`` and ``1.0`` are special-cased to ``"0_0"`` and ``"1_0"``
    so they keep one fractional digit.

    Note: any other value that formats to a whole number at 8 decimal places
    ends in a bare underscore (e.g. ``2.0 -> "2_"``), because stripping the
    zeros leaves only the decimal point behind.

    Args:
        x: The threshold value to render.

    Returns:
        The threshold formatted as described above.
    """
    # Without these guards the generic path would give "0_" and "1_".
    if x == 0.0:
        return "0_0"
    if x == 1.0:
        return "1_0"
    return f"{x:.8f}".rstrip("0").replace(".", "_")


def _generate_detailed_cluster_comparison_sql(
all_results: dict[float, SplinkDataFrame],
unique_id_col: str = "unique_id",
) -> str:
thresholds = sorted(all_results.keys())

def threshold_to_str(x):
if x == 0.0:
return "0_0"
elif x == 1.0:
return "1_0"
else:
return f"{x:.8f}".rstrip("0").replace(".", "_")

select_columns = [f"t0.{unique_id_col}"] + [
f"t{i}.cluster_id AS cluster_{threshold_to_str(threshold)}"
f"t{i}.cluster_id AS cluster_{_threshold_to_str(threshold)}"
for i, threshold in enumerate(thresholds)
]

Expand All @@ -258,24 +259,41 @@ def threshold_to_str(x):
return sql


def _generate_distinct_cluster_count_sql(
def _get_cluster_stats_sql(cc: SplinkDataFrame) -> list[dict[str, str]]:
    """Build the two SQL steps that summarise one clustering result.

    Args:
        cc: The clustered table. Only ``cc.templated_name`` is read; the
            generated SQL groups that table by its ``cluster_id`` column.

    Returns:
        Two ``{"sql": ..., "output_table_name": ...}`` dicts in execution
        order:

        1. ``__splink__cluster_sizes``: one row per ``cluster_id`` with its
           row count as ``cluster_size``.
        2. ``__splink__cluster_stats``: a single row with ``num_clusters``,
           ``max_cluster_size`` and ``avg_cluster_size`` computed from the
           first step's output.
    """
    sqls = []

    cluster_sizes_sql = f"""
    SELECT
        cluster_id,
        COUNT(*) AS cluster_size
    FROM {cc.templated_name}
    GROUP BY cluster_id
    """
    sqls.append(
        {"sql": cluster_sizes_sql, "output_table_name": "__splink__cluster_sizes"}
    )

    # Reads from the output table of the previous step, so order matters.
    cluster_stats_sql = """
    SELECT
        COUNT(*) AS num_clusters,
        MAX(cluster_size) AS max_cluster_size,
        AVG(cluster_size) AS avg_cluster_size
    FROM __splink__cluster_sizes
    """
    sqls.append(
        {"sql": cluster_stats_sql, "output_table_name": "__splink__cluster_stats"}
    )

    return sqls


def _generate_cluster_summary_stats_sql(
all_results: dict[float, SplinkDataFrame],
) -> str:
thresholds = sorted(all_results.keys())

def threshold_to_str(x):
if x == 0.0:
return "0_0"
elif x == 1.0:
return "1_0"
else:
return f"{x:.8f}".rstrip("0").replace(".", "_")

select_statements = [
f"""
SELECT
cast({threshold} as float) AS threshold,
distinct_clusters
SELECT cast({threshold} as float) as threshold, *
FROM {all_results[threshold].physical_name}
"""
for threshold in thresholds
Expand All @@ -294,9 +312,9 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
match_probability_thresholds: list[float],
edge_id_column_name_left: Optional[str] = None,
edge_id_column_name_right: Optional[str] = None,
output_number_of_distinct_clusters_only: bool = False,
output_cluster_summary_stats: bool = False,
) -> SplinkDataFrame:
"""Clusters the pairwise match predictions at multiple thresholds using
"""Clusters the pairwise match predictions at multiple thresholds using
the connected components graph clustering algorithm.

This function efficiently computes clusters for multiple thresholds by starting
Expand All @@ -318,13 +336,14 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
left edge IDs. If not provided, assumed to be f"{node_id_column_name}_l"
edge_id_column_name_right (Optional[str]): The name of the column containing
right edge IDs. If not provided, assumed to be f"{node_id_column_name}_r"
output_number_of_distinct_clusters_only (bool): If True, only output the number
of distinct clusters for each threshold instead of full cluster information
output_cluster_summary_stats (bool): If True, output summary statistics
for each threshold instead of full cluster information

Returns:
SplinkDataFrame: A SplinkDataFrame containing cluster information for all
thresholds. If output_number_of_distinct_clusters_only is True, it contains
the count of distinct clusters for each threshold.
thresholds. If output_cluster_summary_stats is True, it contains summary
statistics (number of clusters, max cluster size, avg cluster size) for
each threshold.

Examples:
```python
Expand Down Expand Up @@ -413,15 +432,12 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
threshold_match_probability=initial_threshold,
)

if output_number_of_distinct_clusters_only:
if output_cluster_summary_stats:
pipeline = CTEPipeline([cc])
sql = f"""
select count(distinct cluster_id) as distinct_clusters
from {cc.templated_name}
"""
pipeline.enqueue_sql(sql, "__splink__distinct_clusters_at_threshold")
cc_distinct = db_api.sql_pipeline_to_splink_dataframe(pipeline)
all_results[initial_threshold] = cc_distinct
sqls = _get_cluster_stats_sql(cc)
pipeline.enqueue_list_of_sqls(sqls)
cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
all_results[initial_threshold] = cc_summary
else:
all_results[initial_threshold] = cc

Expand Down Expand Up @@ -498,21 +514,18 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
stable_clusters.drop_table_from_database_and_remove_from_cache()
marginal_new_clusters.drop_table_from_database_and_remove_from_cache()

if output_number_of_distinct_clusters_only:
if output_cluster_summary_stats:
pipeline = CTEPipeline([cc])
sql = f"""
select count(distinct cluster_id) as distinct_clusters
from {cc.templated_name}
"""
pipeline.enqueue_sql(sql, "__splink__distinct_clusters_at_threshold")
cc_distinct = db_api.sql_pipeline_to_splink_dataframe(pipeline)
all_results[new_threshold] = cc_distinct
sqls = _get_cluster_stats_sql(cc)
pipeline.enqueue_list_of_sqls(sqls)
cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
all_results[new_threshold] = cc_summary
previous_cc.drop_table_from_database_and_remove_from_cache()
else:
all_results[new_threshold] = cc

if output_number_of_distinct_clusters_only:
sql = _generate_distinct_cluster_count_sql(all_results)
if output_cluster_summary_stats:
sql = _generate_cluster_summary_stats_sql(all_results)
else:
sql = _generate_detailed_cluster_comparison_sql(
all_results,
Expand Down
Loading