diff --git a/splink/internals/clustering.py b/splink/internals/clustering.py index bf216a3ced..b1c9271234 100644 --- a/splink/internals/clustering.py +++ b/splink/internals/clustering.py @@ -223,22 +223,23 @@ def _calculate_stable_clusters_at_new_threshold( return sqls +def _threshold_to_str(x): + if x == 0.0: + return "0_0" + elif x == 1.0: + return "1_0" + else: + return f"{x:.8f}".rstrip("0").replace(".", "_") + + def _generate_detailed_cluster_comparison_sql( all_results: dict[float, SplinkDataFrame], unique_id_col: str = "unique_id", ) -> str: thresholds = sorted(all_results.keys()) - def threshold_to_str(x): - if x == 0.0: - return "0_0" - elif x == 1.0: - return "1_0" - else: - return f"{x:.8f}".rstrip("0").replace(".", "_") - select_columns = [f"t0.{unique_id_col}"] + [ - f"t{i}.cluster_id AS cluster_{threshold_to_str(threshold)}" + f"t{i}.cluster_id AS cluster_{_threshold_to_str(threshold)}" for i, threshold in enumerate(thresholds) ] @@ -258,24 +259,41 @@ def threshold_to_str(x): return sql -def _generate_distinct_cluster_count_sql( +def _get_cluster_stats_sql(cc: SplinkDataFrame) -> list[dict[str, str]]: + sqls = [] + cluster_sizes_sql = f""" + SELECT + cluster_id, + COUNT(*) AS cluster_size + FROM {cc.templated_name} + GROUP BY cluster_id + """ + sqls.append( + {"sql": cluster_sizes_sql, "output_table_name": "__splink__cluster_sizes"} + ) + + cluster_stats_sql = """ + SELECT + COUNT(*) AS num_clusters, + MAX(cluster_size) AS max_cluster_size, + AVG(cluster_size) AS avg_cluster_size + FROM __splink__cluster_sizes + """ + sqls.append( + {"sql": cluster_stats_sql, "output_table_name": "__splink__cluster_stats"} + ) + + return sqls + + +def _generate_cluster_summary_stats_sql( all_results: dict[float, SplinkDataFrame], ) -> str: thresholds = sorted(all_results.keys()) - def threshold_to_str(x): - if x == 0.0: - return "0_0" - elif x == 1.0: - return "1_0" - else: - return f"{x:.8f}".rstrip("0").replace(".", "_") - select_statements = [ f""" - SELECT - cast({threshold} as float) AS threshold, - distinct_clusters + SELECT cast({threshold} as float) as threshold, * FROM {all_results[threshold].physical_name} """ for threshold in thresholds @@ -294,9 +312,9 @@ def cluster_pairwise_predictions_at_multiple_thresholds( match_probability_thresholds: list[float], edge_id_column_name_left: Optional[str] = None, edge_id_column_name_right: Optional[str] = None, - output_number_of_distinct_clusters_only: bool = False, + output_cluster_summary_stats: bool = False, ) -> SplinkDataFrame: - """Clusters the pairwise match predictions at multiple thresholds using + """Clusters the pairwise match predictions at multiple thresholds using the connected components graph clustering algorithm. This function efficiently computes clusters for multiple thresholds by starting @@ -318,13 +336,14 @@ def cluster_pairwise_predictions_at_multiple_thresholds( left edge IDs. If not provided, assumed to be f"{node_id_column_name}_l" edge_id_column_name_right (Optional[str]): The name of the column containing right edge IDs. If not provided, assumed to be f"{node_id_column_name}_r" - output_number_of_distinct_clusters_only (bool): If True, only output the number - of distinct clusters for each threshold instead of full cluster information + output_cluster_summary_stats (bool): If True, output summary statistics + for each threshold instead of full cluster information Returns: SplinkDataFrame: A SplinkDataFrame containing cluster information for all - thresholds. If output_number_of_distinct_clusters_only is True, it contains - the count of distinct clusters for each threshold. + thresholds. If output_cluster_summary_stats is True, it contains summary + statistics (number of clusters, max cluster size, avg cluster size) for + each threshold. Examples: ```python @@ -413,15 +432,12 @@ def cluster_pairwise_predictions_at_multiple_thresholds( threshold_match_probability=initial_threshold, ) - if output_number_of_distinct_clusters_only: + if output_cluster_summary_stats: pipeline = CTEPipeline([cc]) - sql = f""" - select count(distinct cluster_id) as distinct_clusters - from {cc.templated_name} - """ - pipeline.enqueue_sql(sql, "__splink__distinct_clusters_at_threshold") - cc_distinct = db_api.sql_pipeline_to_splink_dataframe(pipeline) - all_results[initial_threshold] = cc_distinct + sqls = _get_cluster_stats_sql(cc) + pipeline.enqueue_list_of_sqls(sqls) + cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline) + all_results[initial_threshold] = cc_summary else: all_results[initial_threshold] = cc @@ -498,21 +514,18 @@ def cluster_pairwise_predictions_at_multiple_thresholds( stable_clusters.drop_table_from_database_and_remove_from_cache() marginal_new_clusters.drop_table_from_database_and_remove_from_cache() - if output_number_of_distinct_clusters_only: + if output_cluster_summary_stats: pipeline = CTEPipeline([cc]) - sql = f""" - select count(distinct cluster_id) as distinct_clusters - from {cc.templated_name} - """ - pipeline.enqueue_sql(sql, "__splink__distinct_clusters_at_threshold") - cc_distinct = db_api.sql_pipeline_to_splink_dataframe(pipeline) - all_results[new_threshold] = cc_distinct + sqls = _get_cluster_stats_sql(cc) + pipeline.enqueue_list_of_sqls(sqls) + cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline) + all_results[new_threshold] = cc_summary previous_cc.drop_table_from_database_and_remove_from_cache() else: all_results[new_threshold] = cc - if output_number_of_distinct_clusters_only: - sql = _generate_distinct_cluster_count_sql(all_results) + if output_cluster_summary_stats: + sql = _generate_cluster_summary_stats_sql(all_results) else: sql = _generate_detailed_cluster_comparison_sql( all_results,