
Merge pull request #1677 from moj-analytical-services/cluster_metrics
Cluster metrics
zslade authored Nov 8, 2023
2 parents f505f6f + 7570b63 commit 442810b
Showing 3 changed files with 155 additions and 0 deletions.
65 changes: 65 additions & 0 deletions splink/cluster_metrics.py
@@ -0,0 +1,65 @@
from splink.input_column import InputColumn


def _size_density_sql(
    df_predict, df_clustered, threshold_match_probability, _unique_id_col
):
    """Generates sql for computing cluster size and density at a given threshold.

    Args:
        df_predict (SplinkDataFrame): The results of `linker.predict()`
        df_clustered (SplinkDataFrame): The outputs of
            `linker.cluster_pairwise_predictions_at_threshold()`
        threshold_match_probability (float): Filter the pairwise match
            predictions to include only pairwise comparisons with a
            match_probability above this threshold.
        _unique_id_col (string): name of unique id column in settings dict

    Returns:
        A list of dicts, each containing a sql string and an output table name,
        for computing cluster size and density
    """

    # Get physical table names from Splink dataframes
    edges_table = df_predict.physical_name
    clusters_table = df_clustered.physical_name

    input_col = InputColumn(_unique_id_col)
    unique_id_col_l = input_col.name_l()

    sqls = []
    sql = f"""
        SELECT
            {unique_id_col_l},
            COUNT(*) AS count_edges
        FROM {edges_table}
        WHERE match_probability >= {threshold_match_probability}
        GROUP BY {unique_id_col_l}
    """

    sql = {"sql": sql, "output_table_name": "__splink__count_edges"}
    sqls.append(sql)

    sql = f"""
        SELECT
            c.cluster_id,
            count(*) AS n_nodes,
            sum(e.count_edges) AS n_edges
        FROM {clusters_table} AS c
        LEFT JOIN __splink__count_edges e ON c.{_unique_id_col} = e.{unique_id_col_l}
        GROUP BY c.cluster_id
    """
    sql = {"sql": sql, "output_table_name": "__splink__counts_per_cluster"}
    sqls.append(sql)

    sql = """
        SELECT
            cluster_id,
            n_nodes,
            n_edges,
            (n_edges * 2)/(n_nodes * (n_nodes-1)) AS density
        FROM __splink__counts_per_cluster
    """
    sql = {"sql": sql, "output_table_name": "__splink__cluster_metrics_clusters"}
    sqls.append(sql)

    return sqls
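For reference, the density computed in the final step is the share of possible undirected edges actually present in a cluster, i.e. 2 * n_edges / (n_nodes * (n_nodes - 1)). A minimal plain-Python sketch of the same calculation (the `cluster_density` helper below is illustrative only and is not part of this PR):

def cluster_density(n_nodes: int, n_edges: int) -> float:
    """Share of the possible undirected edges that are present in a cluster."""
    max_possible_edges = n_nodes * (n_nodes - 1) / 2
    return n_edges / max_possible_edges

# A cluster of 3 nodes joined by 2 edges has density 2/3, matching the
# expected values in tests/test_cluster_metrics.py below.
assert abs(cluster_density(3, 2) - 2 / 3) < 1e-9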
40 changes: 40 additions & 0 deletions splink/linker.py
@@ -50,6 +50,7 @@
    unlinkables_chart,
    waterfall_chart,
)
from .cluster_metrics import _size_density_sql
from .cluster_studio import render_splink_cluster_studio_html
from .comparison import Comparison
from .comparison_level import ComparisonLevel
@@ -2082,6 +2083,45 @@ def cluster_pairwise_predictions_at_threshold(

        return cc

    def _compute_cluster_metrics(
        self,
        df_predict: SplinkDataFrame,
        df_clustered: SplinkDataFrame,
        threshold_match_probability: float = None,
    ):
        """Generates a table containing cluster metrics and returns a Splink dataframe

        Args:
            df_predict (SplinkDataFrame): The results of `linker.predict()`
            df_clustered (SplinkDataFrame): The outputs of
                `linker.cluster_pairwise_predictions_at_threshold()`
            threshold_match_probability (float): Filter the pairwise match predictions
                to include only pairwise comparisons with a match_probability above this
                threshold.

        Returns:
            SplinkDataFrame: A SplinkDataFrame containing cluster IDs and selected
                cluster metrics
        """

        # Get unique row id column name from settings
        unique_id_col = self._settings_obj._unique_id_column_name

        sqls = _size_density_sql(
            df_predict,
            df_clustered,
            threshold_match_probability,
            _unique_id_col=unique_id_col,
        )

        for sql in sqls:
            self._enqueue_sql(sql["sql"], sql["output_table_name"])

        df_cluster_metrics = self._execute_sql_pipeline()

        return df_cluster_metrics

    def profile_columns(
        self, column_expressions: str | list[str] = None, top_n=10, bottom_n=10
    ):
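As a rough usage sketch (not part of the diff), the new method would typically be called on the outputs of `predict()` and `cluster_pairwise_predictions_at_threshold()`; the `settings` dict, input dataframe `df` and the 0.95 threshold below are assumed for illustration:

from splink.duckdb.linker import DuckDBLinker

linker = DuckDBLinker(df, settings)

df_predict = linker.predict()
df_clustered = linker.cluster_pairwise_predictions_at_threshold(
    df_predict, threshold_match_probability=0.95
)

# New in this PR: one row per cluster with n_nodes, n_edges and density
df_cluster_metrics = linker._compute_cluster_metrics(
    df_predict, df_clustered, threshold_match_probability=0.95
)
df_cluster_metrics.as_pandas_dataframe()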
50 changes: 50 additions & 0 deletions tests/test_cluster_metrics.py
@@ -0,0 +1,50 @@
import pandas as pd
from pandas.testing import assert_frame_equal

from splink.duckdb.linker import DuckDBLinker

# Dummy df
person_ids = [i + 1 for i in range(5)]
df = pd.DataFrame({"person_id": person_ids})

# Dummy edges df
edges_data = [
    # cluster A edges
    {"person_id_l": 1, "person_id_r": 2, "match_probability": 0.99},
    {"person_id_l": 1, "person_id_r": 3, "match_probability": 0.99},
    # cluster B edge
    {"person_id_l": 4, "person_id_r": 5, "match_probability": 0.99},
    # edges not in relevant clusters
    {"person_id_l": 10, "person_id_r": 11, "match_probability": 0.99},
    {"person_id_l": 12, "person_id_r": 12, "match_probability": 0.95},
]
edges = pd.DataFrame(edges_data)

# Dummy clusters df
cluster_ids = ["A", "A", "A", "B", "B"]
clusters_data = {"cluster_id": cluster_ids, "person_id": person_ids}
clusters = pd.DataFrame(clusters_data)

# Expected dataframe
expected_data = [
    {"cluster_id": "A", "n_nodes": 3, "n_edges": 2.0, "density": 2 / 3},
    {"cluster_id": "B", "n_nodes": 2, "n_edges": 1.0, "density": 1.0},
]
df_expected = pd.DataFrame(expected_data)


def test_size_density():
    # Linker with basic settings
    settings = {"link_type": "dedupe_only", "unique_id_column_name": "person_id"}
    linker = DuckDBLinker(df, settings)

    # Register as Splink dataframes
    df_predict = linker.register_table(edges, "df_predict", overwrite=True)
    df_clustered = linker.register_table(clusters, "df_clustered", overwrite=True)

    df_cluster_metrics = linker._compute_cluster_metrics(
        df_predict, df_clustered, threshold_match_probability=0.99
    )
    df_cluster_metrics = df_cluster_metrics.as_pandas_dataframe()

    assert_frame_equal(df_cluster_metrics, df_expected)
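For the record, the expected values are consistent with the density formula above: cluster A has 3 nodes and 2 edges at or above the 0.99 threshold, so density = (2 * 2) / (3 * 2) = 2/3; cluster B has 2 nodes and 1 edge, so density = (1 * 2) / (2 * 1) = 1.0. The edges involving ids 10, 11 and 12 do not contribute because those ids are absent from the clusters table, and the 0.95 edge also falls below the 0.99 threshold.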

