diff --git a/splink/cluster_metrics.py b/splink/cluster_metrics.py
new file mode 100644
index 0000000000..d15977245a
--- /dev/null
+++ b/splink/cluster_metrics.py
@@ -0,0 +1,65 @@
+from splink.input_column import InputColumn
+
+
+def _size_density_sql(
+    df_predict, df_clustered, threshold_match_probability, _unique_id_col
+):
+    """Generates sql for computing cluster size and density at a given threshold.
+
+    Args:
+        df_predict (SplinkDataFrame): The results of `linker.predict()`
+        df_clustered (SplinkDataFrame): The outputs of
+            `linker.cluster_pairwise_predictions_at_threshold()`
+        threshold_match_probability (float): Filter the pairwise match
+            predictions to include only pairwise comparisons with a
+            match_probability above this threshold.
+        _unique_id_col (string): Name of the unique id column in the settings dict
+
+    Returns:
+        list of dicts of sql strings and output table names for computing
+    """
+
+    # Get physical table names from Splink dataframes
+    edges_table = df_predict.physical_name
+    clusters_table = df_clustered.physical_name
+
+    input_col = InputColumn(_unique_id_col)
+    unique_id_col_l = input_col.name_l()
+
+    sqls = []
+    sql = f"""
+        SELECT
+            {unique_id_col_l},
+            COUNT(*) AS count_edges
+        FROM {edges_table}
+        WHERE match_probability >= {threshold_match_probability}
+        GROUP BY {unique_id_col_l}
+    """
+
+    sql = {"sql": sql, "output_table_name": "__splink__count_edges"}
+    sqls.append(sql)
+
+    sql = f"""
+        SELECT
+            c.cluster_id,
+            count(*) AS n_nodes,
+            sum(e.count_edges) AS n_edges
+        FROM {clusters_table} AS c
+        LEFT JOIN __splink__count_edges e ON c.{_unique_id_col} = e.{unique_id_col_l}
+        GROUP BY c.cluster_id
+    """
+    sql = {"sql": sql, "output_table_name": "__splink__counts_per_cluster"}
+    sqls.append(sql)
+
+    sql = """
+        SELECT
+            cluster_id,
+            n_nodes,
+            n_edges,
+            (n_edges * 2)/(n_nodes * (n_nodes-1)) AS density
+        FROM __splink__counts_per_cluster
+    """
+    sql = {"sql": sql, "output_table_name": "__splink__cluster_metrics_clusters"}
+    sqls.append(sql)
+
+    return sqls
diff --git a/splink/linker.py b/splink/linker.py
index c4a96b4a87..ae0116f48e 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -50,6 +50,7 @@
     unlinkables_chart,
     waterfall_chart,
 )
+from .cluster_metrics import _size_density_sql
 from .cluster_studio import render_splink_cluster_studio_html
 from .comparison import Comparison
 from .comparison_level import ComparisonLevel
@@ -2082,6 +2083,45 @@
 
         return cc
 
+    def _compute_cluster_metrics(
+        self,
+        df_predict: SplinkDataFrame,
+        df_clustered: SplinkDataFrame,
+        threshold_match_probability: float = None,
+    ):
+        """Generates a table containing cluster metrics and returns a Splink dataframe
+
+        Args:
+            df_predict (SplinkDataFrame): The results of `linker.predict()`
+            df_clustered (SplinkDataFrame): The outputs of
+                `linker.cluster_pairwise_predictions_at_threshold()`
+            threshold_match_probability (float): Filter the pairwise match
+                predictions to include only pairwise comparisons with a
+                match_probability above this threshold.
+
+        Returns:
+            SplinkDataFrame: A SplinkDataFrame containing cluster IDs and selected
+                cluster metrics
+
+        """
+
+        # Get unique row id column name from settings
+        unique_id_col = self._settings_obj._unique_id_column_name
+
+        sqls = _size_density_sql(
+            df_predict,
+            df_clustered,
+            threshold_match_probability,
+            _unique_id_col=unique_id_col,
+        )
+
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+
+        df_cluster_metrics = self._execute_sql_pipeline()
+
+        return df_cluster_metrics
+
     def profile_columns(
         self, column_expressions: str | list[str] = None, top_n=10, bottom_n=10
     ):
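Taken together, the two files above add an internal entry point for cluster metrics. A minimal sketch of the intended call pattern, assuming a DuckDB backend (the settings dict is abbreviated, a real model would also need blocking rules and comparisons, and `my_data.csv` is a hypothetical input):

    import pandas as pd
    from splink.duckdb.linker import DuckDBLinker

    df = pd.read_csv("my_data.csv")  # hypothetical input with a person_id column
    settings = {"link_type": "dedupe_only", "unique_id_column_name": "person_id"}
    linker = DuckDBLinker(df, settings)

    # Standard Splink flow: score pairwise comparisons, then cluster at a threshold
    df_predict = linker.predict()
    df_clustered = linker.cluster_pairwise_predictions_at_threshold(
        df_predict, threshold_match_probability=0.95
    )

    # New in this diff: one row per cluster with n_nodes, n_edges and density
    df_cluster_metrics = linker._compute_cluster_metrics(
        df_predict, df_clustered, threshold_match_probability=0.95
    )

The leading underscore marks `_compute_cluster_metrics` as private, so the public interface may still change; the test below exercises it directly.
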
diff --git a/tests/test_cluster_metrics.py b/tests/test_cluster_metrics.py
new file mode 100644
index 0000000000..da2fe0f5fb
--- /dev/null
+++ b/tests/test_cluster_metrics.py
@@ -0,0 +1,50 @@
+import pandas as pd
+from pandas.testing import assert_frame_equal
+
+from splink.duckdb.linker import DuckDBLinker
+
+# Dummy df
+person_ids = [i + 1 for i in range(5)]
+df = pd.DataFrame({"person_id": person_ids})
+
+# Dummy edges df
+edges_data = [
+    # cluster A edges
+    {"person_id_l": 1, "person_id_r": 2, "match_probability": 0.99},
+    {"person_id_l": 1, "person_id_r": 3, "match_probability": 0.99},
+    # cluster B edge
+    {"person_id_l": 4, "person_id_r": 5, "match_probability": 0.99},
+    # edges not in relevant clusters
+    {"person_id_l": 10, "person_id_r": 11, "match_probability": 0.99},
+    {"person_id_l": 12, "person_id_r": 12, "match_probability": 0.95},
+]
+edges = pd.DataFrame(edges_data)
+
+# Dummy clusters df
+cluster_ids = ["A", "A", "A", "B", "B"]
+clusters_data = {"cluster_id": cluster_ids, "person_id": person_ids}
+clusters = pd.DataFrame(clusters_data)
+
+# Expected dataframe
+expected_data = [
+    {"cluster_id": "A", "n_nodes": 3, "n_edges": 2.0, "density": 2 / 3},
+    {"cluster_id": "B", "n_nodes": 2, "n_edges": 1.0, "density": 1.0},
+]
+df_expected = pd.DataFrame(expected_data)
+
+
+def test_size_density():
+    # Linker with basic settings
+    settings = {"link_type": "dedupe_only", "unique_id_column_name": "person_id"}
+    linker = DuckDBLinker(df, settings)
+
+    # Register as Splink dataframes
+    df_predict = linker.register_table(edges, "df_predict", overwrite=True)
+    df_clustered = linker.register_table(clusters, "df_clustered", overwrite=True)
+
+    df_cluster_metrics = linker._compute_cluster_metrics(
+        df_predict, df_clustered, threshold_match_probability=0.99
+    )
+    df_cluster_metrics = df_cluster_metrics.as_pandas_dataframe()
+
+    assert_frame_equal(df_cluster_metrics, df_expected)
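
The expected values follow from the density formula in `_size_density_sql`: density = 2 * n_edges / (n_nodes * (n_nodes - 1)), i.e. the fraction of possible undirected edges that are present. Checking the fixtures by hand:

    # cluster A: 3 nodes -> 3 * 2 / 2 = 3 possible edges, 2 observed
    assert 2 * 2 / (3 * 2) == 2 / 3
    # cluster B: 2 nodes -> 1 possible edge, 1 observed
    assert 2 * 1 / (2 * 1) == 1.0

The edges among person ids 10-12 match no row of the clusters table, so the LEFT JOIN in `_size_density_sql` drops them (the 0.95 edge also falls below the 0.99 threshold), which is exactly what the expected dataframe encodes.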