One to one clustering #2578

Open
wants to merge 11 commits into master
159 changes: 158 additions & 1 deletion splink/internals/linker_components/clustering.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from splink.internals.connected_components import (
solve_connected_components,
@@ -14,6 +14,9 @@
from splink.internals.misc import (
threshold_args_to_match_prob,
)
from splink.internals.one_to_one_clustering import (
one_to_one_clustering,
)
from splink.internals.pipeline import CTEPipeline
from splink.internals.splink_dataframe import SplinkDataFrame
from splink.internals.unique_id_concat import (
@@ -177,6 +180,160 @@ def cluster_pairwise_predictions_at_threshold(

return df_clustered_with_input_data

def cluster_using_single_best_links(
self,
df_predict: SplinkDataFrame,
duplicate_free_datasets: List[str],
threshold_match_probability: Optional[float] = None,
threshold_match_weight: Optional[float] = None,
) -> SplinkDataFrame:
"""
Clusters the pairwise match predictions that result from
`linker.inference.predict()` into groups of connected records using a single
best links method that restricts the clusters to have at most one record from
each source dataset in the `duplicate_free_datasets` list.

This method adds a record to a cluster only if the record and the cluster are
mutually each other's best match, and only if adding the record does not
violate the constraint of having at most one record from each of the
`duplicate_free_datasets`.

Args:
df_predict (SplinkDataFrame): The results of `linker.inference.predict()`
duplicate_free_datasets (List[str]): The source datasets which should be
treated as having no duplicates. Clusters will not form with more than
one record from any of these datasets. This can be a subset of all of
the source datasets in the input data.
threshold_match_probability (float, optional): Pairwise comparisons with a
`match_probability` at or above this threshold are matched
threshold_match_weight (float, optional): Pairwise comparisons with a
`match_weight` at or above this threshold are matched. Only one of
threshold_match_probability or threshold_match_weight should be provided

Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
into groups based on the desired match threshold and the source datasets
for which duplicates are not allowed.

Examples:
```python
df_predict = linker.inference.predict(threshold_match_probability=0.5)
df_clustered = linker.clustering.cluster_using_single_best_links(
df_predict,
duplicate_free_datasets=["A", "B"],
threshold_match_probability=0.95
)
```
"""
linker = self._linker
db_api = linker._db_api

pipeline = CTEPipeline()

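# Make the concatenated input data (__splink__df_concat) available to the
# SQL below.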
enqueue_df_concat(linker, pipeline)

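# Build the SQL expressions for composite unique ids (the unique id columns
# concatenated together) for the left/right side of edges and for nodes.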
uid_cols = linker._settings_obj.column_info_settings.unique_id_input_columns
uid_concat_edges_l = _composite_unique_id_from_edges_sql(uid_cols, "l")
uid_concat_edges_r = _composite_unique_id_from_edges_sql(uid_cols, "r")
uid_concat_nodes = _composite_unique_id_from_nodes_sql(uid_cols, None)

source_dataset_column_name = (
linker._settings_obj.column_info_settings.source_dataset_column_name
)

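# Nodes table: one row per input record, mapping its composite unique id to
# its source dataset. The source dataset is needed to enforce the
# "at most one record per duplicate-free dataset" constraint.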
sql = f"""
select
{uid_concat_nodes} as node_id,
{source_dataset_column_name} as source_dataset
from __splink__df_concat
"""
pipeline.enqueue_sql(sql, "__splink__df_nodes_with_composite_ids")

nodes_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(pipeline)

has_match_prob_col = "match_probability" in [
c.unquote().name for c in df_predict.columns
]

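# Convert threshold_match_weight, if supplied, into the equivalent match
# probability so only one threshold needs to be handled below.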
threshold_match_probability = threshold_args_to_match_prob(
threshold_match_probability, threshold_match_weight
)

if not has_match_prob_col and threshold_match_probability is not None:
raise ValueError(
"df_predict must have a column called 'match_probability' if "
"threshold_match_probability is provided"
)

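# When a threshold is supplied, filter the edges to those at or above it and
# carry match_probability through to the clustering step.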
match_p_expr = ""
match_p_select_expr = ""
if threshold_match_probability is not None:
match_p_expr = f"where match_probability >= {threshold_match_probability}"
match_p_select_expr = ", match_probability"

pipeline = CTEPipeline([df_predict])

# The templated name must be used here because df_predict could be the output
# of a deterministic link, i.e. its name is not known for sure
sql = f"""
select
{uid_concat_edges_l} as node_id_l,
{uid_concat_edges_r} as node_id_r
{match_p_select_expr}
from {df_predict.templated_name}
{match_p_expr}
"""
pipeline.enqueue_sql(sql, "__splink__df_edges_from_predict")

edges_table_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(
pipeline
)

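# Perform the single best links clustering: a record joins a cluster only via
# a mutually-best link, and never in a way that puts two records from the same
# duplicate-free dataset into one cluster.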
oo = one_to_one_clustering(
nodes_table=nodes_with_composite_ids,
edges_table=edges_table_with_composite_ids,
node_id_column_name="node_id",
source_dataset_column_name="source_dataset",
edge_id_column_name_left="node_id_l",
edge_id_column_name_right="node_id_r",
duplicate_free_datasets=duplicate_free_datasets,
db_api=db_api,
threshold_match_probability=threshold_match_probability,
)

edges_table_with_composite_ids.drop_table_from_database_and_remove_from_cache()
nodes_with_composite_ids.drop_table_from_database_and_remove_from_cache()
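
# Join the cluster assignments back onto the original input data so the output
# contains each record's input columns alongside its cluster_id.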
pipeline = CTEPipeline([oo])

enqueue_df_concat(linker, pipeline)

columns = concat_table_column_names(self._linker)
# don't want to include salting column in output if present
columns_without_salt = filter(lambda x: x != "__splink_salt", columns)

select_columns_sql = ", ".join(columns_without_salt)

sql = f"""
select
oo.cluster_id,
{select_columns_sql}
from {oo.templated_name} as oo
left join __splink__df_concat
on oo.node_id = {uid_concat_nodes}
"""
pipeline.enqueue_sql(sql, "__splink__df_clustered_with_input_data")

df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline)

oo.drop_table_from_database_and_remove_from_cache()

if threshold_match_probability is not None:
df_clustered_with_input_data.metadata["threshold_match_probability"] = (
threshold_match_probability
)

return df_clustered_with_input_data

def _compute_metrics_nodes(
self,
df_predict: SplinkDataFrame,