Skip to content

Commit

Permalink
filter neighbours and reps
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Sep 30, 2024
1 parent 7c181aa commit 784940f
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions splink/internals/connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _cc_find_converged_nodes(representatives_name: str, neighbours_name: str) ->
sqls = []

sql_non_stable = f"""
SELECT DISTINCT r.node_id
SELECT DISTINCT r.representative
FROM {representatives_name} r
JOIN {neighbours_name} n ON r.node_id = n.node_id
JOIN {representatives_name} r2 ON n.neighbour = r2.node_id
Expand Down Expand Up @@ -390,9 +390,6 @@ def solve_connected_components(

prev_representatives_table = representatives

print(f"representatives {representatives.physical_name}")
representatives.as_duckdbpyrelation().show()

# Loop while our representative table still has unsettled nodes
# (nodes where the representative has changed since the last iteration)
converged_clusters_tables = []
Expand All @@ -404,6 +401,10 @@ def solve_connected_components(
iteration += 1
print("-" * 40 + f" Iteration {iteration} " + "-" * 40)

c = prev_representatives_table.as_duckdbpyrelation().count("*").fetchone()[0]
print(f"Number of representatives: {c:,.0f}")
c = filtered_neighbours.as_duckdbpyrelation().count("*").fetchone()[0]
print(f"Number of filtered neighbours: {c:,.0f}")
# Loop summary:
# 1. Find stable clusters and remove from representatives table
# Stable clusters are those where a set of nodes are within the same cluster
Expand All @@ -415,14 +416,16 @@ def solve_connected_components(
# to create the "needs_updating" column based on whether rep has changed
# 4. Assess if any representatives changed between iterations, exit if not.

# 1. Find stable clusters and remove from representatives table
# 1a. Find stable clusters and remove from representatives table
pipeline = CTEPipeline([filtered_neighbours, prev_representatives_table])
unconverged_nodes_sqls = _cc_find_converged_nodes(
representatives.templated_name, filtered_neighbours.templated_name
converged_nodes_sqls = _cc_find_converged_nodes(
prev_representatives_table.templated_name,
filtered_neighbours.templated_name,
)
pipeline.enqueue_list_of_sqls(unconverged_nodes_sqls)
pipeline.enqueue_list_of_sqls(converged_nodes_sqls)

representatives_stable = db_api.sql_pipeline_to_splink_dataframe(pipeline)

converged_clusters_tables.append(representatives_stable)

# Remove stable clusters from representatives table
Expand All @@ -437,6 +440,17 @@ def solve_connected_components(
pipeline.enqueue_sql(sql, "__splink__representatives_stable")
prev_representatives_thinned = db_api.sql_pipeline_to_splink_dataframe(pipeline)

# 1b. Thin neighbours table - we can drop all rows that refer to
# node_ids that have converged
pipeline = CTEPipeline([prev_representatives_thinned, filtered_neighbours])
sql = f"""
select * from {filtered_neighbours.templated_name}
where node_id in
(select node_id from {prev_representatives_thinned.templated_name})
"""
pipeline.enqueue_sql(sql, "__splink__df_neighbours_filtered")
filtered_neighbours = db_api.sql_pipeline_to_splink_dataframe(pipeline)

# Generates our representatives table for the next iteration
# by joining our previous tables onto our neighbours table.
pipeline = CTEPipeline([neighbours])
Expand All @@ -458,9 +472,6 @@ def solve_connected_components(

representatives = db_api.sql_pipeline_to_splink_dataframe(pipeline)

print(f"representatives {representatives.physical_name}")
representatives.as_duckdbpyrelation().show()

pipeline = CTEPipeline()
# Update table reference
prev_representatives_table.drop_table_from_database_and_remove_from_cache()
Expand Down Expand Up @@ -489,7 +500,7 @@ def solve_connected_components(

sql = " UNION ALL ".join(
[
f"""select representative as cluster_id, node_id as {node_id_column_name}
f"""select node_id as {node_id_column_name}, representative as cluster_id,
from {t.physical_name}"""
for t in converged_clusters_tables
]
Expand Down

0 comments on commit 784940f

Please sign in to comment.