Merge pull request #1695 from moj-analytical-services/blocking_sql_queueing_refactor

Refactor `block_using_rules_sql` to follow normal pattern and avoid confusion
RobinL authored Nov 3, 2023
2 parents 6db77fc + f36f498 commit a0240cf
Showing 6 changed files with 89 additions and 60 deletions.
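
The net effect for callers, visible in all six files below, is that `block_using_rules_sqls` now returns a list of `{"sql": ..., "output_table_name": ...}` dicts to be enqueued one by one, rather than a single SQL string hard-wired to one output table. A minimal before/after sketch of the calling pattern (illustrative only; `linker` stands for any internal Splink linker object, and `_enqueue_sql` is the internal pipeline method shown in the diff):

```python
# Before: one SQL string, output table name supplied by the caller
sql = block_using_rules_sql(linker)
linker._enqueue_sql(sql, "__splink__df_blocked")

# After: a list of dicts, each carrying its own output table name,
# with the final entry producing "__splink__df_blocked"
sqls = block_using_rules_sqls(linker)
for sql in sqls:
    linker._enqueue_sql(sql["sql"], sql["output_table_name"])
```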
9 changes: 5 additions & 4 deletions splink/analyse_blocking.py
```diff
@@ -5,7 +5,7 @@

 import pandas as pd

-from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sql
+from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sqls
 from .misc import calculate_cartesian, calculate_reduction_ratio

 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
@@ -83,16 +83,17 @@ def cumulative_comparisons_generated_by_blocking_rules(
     cartesian = calculate_cartesian(row_count_df, settings_obj._link_type)

     # Calculate the total number of rows generated by each blocking rule
-    sql = block_using_rules_sql(linker)
-    linker._enqueue_sql(sql, "__splink__df_blocked_data")
+    sqls = block_using_rules_sqls(linker)
+    for sql in sqls:
+        linker._enqueue_sql(sql["sql"], sql["output_table_name"])

     brs_as_objs = linker._settings_obj_._blocking_rules_to_generate_predictions

     sql = """
         select
         count(*) as row_count,
         match_key
-        from __splink__df_blocked_data
+        from __splink__df_blocked
         group by match_key
         order by cast(match_key as int) asc
     """
```
90 changes: 55 additions & 35 deletions splink/blocking.py
```diff
@@ -197,7 +197,7 @@ def _sql_gen_where_condition(link_type, unique_id_cols):


 # flake8: noqa: C901
-def block_using_rules_sql(linker: Linker):
+def block_using_rules_sqls(linker: Linker):
     """Use the blocking rules specified in the linker's settings object to
     generate a SQL statement that will create pairwise record comparions
     according to the blocking rule(s).
@@ -206,6 +206,54 @@
     so that duplicate comparisons are not generated.
     """

+    sqls = []
+
+    # For the two dataset link only, rather than a self join of
+    # __splink__df_concat_with_tf, it's much faster to split the input
+    # into two tables, and join (because then Splink doesn't have to evaluate)
+    # intra-dataset comparisons.
+    # see https://github.com/moj-analytical-services/splink/pull/1359
+    if (
+        linker._two_dataset_link_only
+        and not linker._find_new_matches_mode
+        and not linker._compare_two_records_mode
+    ):
+        source_dataset_col = linker._source_dataset_column_name
+        # Need df_l to be the one with the lowest id to preeserve the property
+        # that the left dataset is the one with the lowest concatenated id
+        keys = linker._input_tables_dict.keys()
+        keys = list(sorted(keys))
+        df_l = linker._input_tables_dict[keys[0]]
+        df_r = linker._input_tables_dict[keys[1]]
+
+        # This also needs to work for training u
+        if linker._train_u_using_random_sample_mode:
+            spl_switch = "_sample"
+        else:
+            spl_switch = ""
+
+        sql = f"""
+        select * from __splink__df_concat_with_tf{spl_switch}
+        where {source_dataset_col} = '{df_l.templated_name}'
+        """
+        sqls.append(
+            {
+                "sql": sql,
+                "output_table_name": f"__splink__df_concat_with_tf{spl_switch}_left",
+            }
+        )
+
+        sql = f"""
+        select * from __splink__df_concat_with_tf{spl_switch}
+        where {source_dataset_col} = '{df_r.templated_name}'
+        """
+        sqls.append(
+            {
+                "sql": sql,
+                "output_table_name": f"__splink__df_concat_with_tf{spl_switch}_right",
+            }
+        )
+
     if type(linker).__name__ in ["SparkLinker"]:
         apply_salt = True
     else:
@@ -243,36 +291,6 @@ def block_using_rules_sql(linker: Linker):
         " will not be implemented for this run."
     )

-    if (
-        linker._two_dataset_link_only
-        and not linker._find_new_matches_mode
-        and not linker._compare_two_records_mode
-    ):
-        source_dataset_col = linker._source_dataset_column_name
-        # Need df_l to be the one with the lowest id to preeserve the property
-        # that the left dataset is the one with the lowest concatenated id
-        keys = linker._input_tables_dict.keys()
-        keys = list(sorted(keys))
-        df_l = linker._input_tables_dict[keys[0]]
-        df_r = linker._input_tables_dict[keys[1]]
-
-        if linker._train_u_using_random_sample_mode:
-            sample_switch = "_sample"
-        else:
-            sample_switch = ""
-
-        sql = f"""
-        select * from __splink__df_concat_with_tf{sample_switch}
-        where {source_dataset_col} = '{df_l.templated_name}'
-        """
-        linker._enqueue_sql(sql, f"__splink__df_concat_with_tf{sample_switch}_left")
-
-        sql = f"""
-        select * from __splink__df_concat_with_tf{sample_switch}
-        where {source_dataset_col} = '{df_r.templated_name}'
-        """
-        linker._enqueue_sql(sql, f"__splink__df_concat_with_tf{sample_switch}_right")
-
     # Cover the case where there are no blocking rules
     # This is a bit of a hack where if you do a self-join on 'true'
     # you create a cartesian product, rather than having separate code
@@ -287,7 +305,7 @@
     else:
         probability = ""

-    sqls = []
+    br_sqls = []
     for br in blocking_rules:
         # Apply our salted rules to resolve skew issues. If no salt was
         # selected to be added, then apply the initial blocking rule.
@@ -310,8 +328,10 @@
        {where_condition}
        """

-        sqls.append(sql)
+        br_sqls.append(sql)

-    sql = "union all".join(sqls)
+    sql = "union all".join(br_sqls)

-    return sql
+    sqls.append({"sql": sql, "output_table_name": "__splink__df_blocked"})
+
+    return sqls
```
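
For orientation, a hypothetical example of what the returned `sqls` list might look like for a two-dataset link job. The structure (left/right split tables followed by the unioned blocked table) comes from the diff above, but the SQL text and column names here are illustrative placeholders, not Splink's generated output:

```python
# Hypothetical shape of the return value of block_using_rules_sqls
# for a two-dataset link (values illustrative only):
sqls = [
    {
        "sql": "select * from __splink__df_concat_with_tf where source_dataset = 'df_left'",
        "output_table_name": "__splink__df_concat_with_tf_left",
    },
    {
        "sql": "select * from __splink__df_concat_with_tf where source_dataset = 'df_right'",
        "output_table_name": "__splink__df_concat_with_tf_right",
    },
    {
        # one SELECT per blocking rule, tagged with a match_key and joined by "union all"
        "sql": "select ... , 0 as match_key from ... union all select ... , 1 as match_key from ...",
        "output_table_name": "__splink__df_blocked",
    },
]
```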
7 changes: 4 additions & 3 deletions splink/em_training_session.py
```diff
@@ -4,7 +4,7 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING

-from .blocking import BlockingRule, block_using_rules_sql
+from .blocking import BlockingRule, block_using_rules_sqls
 from .charts import (
     m_u_parameters_interactive_history_chart,
     match_weights_interactive_history_chart,
@@ -151,8 +151,9 @@ def _comparison_vectors(self):

         nodes_with_tf = self._original_linker._initialise_df_concat_with_tf()

-        sql = block_using_rules_sql(self._training_linker)
-        self._training_linker._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self._training_linker)
+        for sql in sqls:
+            self._training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])

         # repartition after blocking only exists on the SparkLinker
         repartition_after_blocking = getattr(
```
7 changes: 4 additions & 3 deletions splink/estimate_u.py
```diff
@@ -4,7 +4,7 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING

-from .blocking import block_using_rules_sql
+from .blocking import block_using_rules_sqls
 from .comparison_vector_values import compute_comparison_vector_values_sql
 from .expectation_maximisation import (
     compute_new_parameters_sql,
@@ -106,8 +106,9 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):

     settings_obj._blocking_rules_to_generate_predictions = []

-    sql = block_using_rules_sql(training_linker)
-    training_linker._enqueue_sql(sql, "__splink__df_blocked")
+    sqls = block_using_rules_sqls(training_linker)
+    for sql in sqls:
+        training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])

     # repartition after blocking only exists on the SparkLinker
     repartition_after_blocking = getattr(
```
28 changes: 16 additions & 12 deletions splink/linker.py
```diff
@@ -33,7 +33,7 @@
 )
 from .blocking import (
     BlockingRule,
-    block_using_rules_sql,
+    block_using_rules_sqls,
     blocking_rule_to_obj,
 )
 from .cache_dict_with_logging import CacheDictWithLogging
@@ -1406,8 +1406,9 @@ def deterministic_link(self) -> SplinkDataFrame:
         self._deterministic_link_mode = True

         concat_with_tf = self._initialise_df_concat_with_tf()
-        sql = block_using_rules_sql(self)
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])
         return self._execute_sql_pipeline([concat_with_tf])

     def estimate_u_using_random_sampling(
@@ -1728,8 +1729,9 @@ def predict(
         if nodes_with_tf:
             input_dataframes.append(nodes_with_tf)

-        sql = block_using_rules_sql(self)
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])

         repartition_after_blocking = getattr(self, "repartition_after_blocking", False)

@@ -1853,8 +1855,9 @@ def find_matches_to_new_records(

         add_unique_id_and_source_dataset_cols_if_needed(self, new_records_df)

-        sql = block_using_rules_sql(self)
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])

         sql = compute_comparison_vector_values_sql(self._settings_obj)
         self._enqueue_sql(sql, "__splink__df_comparison_vectors")
@@ -1937,8 +1940,9 @@ def compare_two_records(self, record_1: dict, record_2: dict):

         self._enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf")

-        sql = block_using_rules_sql(self)
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])

         sql = compute_comparison_vector_values_sql(self._settings_obj)
         self._enqueue_sql(sql, "__splink__df_comparison_vectors")
@@ -1993,9 +1997,9 @@ def _self_link(self) -> SplinkDataFrame:

         nodes_with_tf = self._initialise_df_concat_with_tf()

-        sql = block_using_rules_sql(self)
-
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])

         sql = compute_comparison_vector_values_sql(self._settings_obj)

```
8 changes: 5 additions & 3 deletions splink/m_training.py
```diff
@@ -1,7 +1,7 @@
 import logging
 from copy import deepcopy

-from .blocking import BlockingRule, block_using_rules_sql
+from .blocking import BlockingRule, block_using_rules_sqls
 from .comparison_vector_values import compute_comparison_vector_values_sql
 from .expectation_maximisation import (
     compute_new_parameters_sql,
@@ -34,8 +34,10 @@ def estimate_m_values_from_label_column(linker, df_dict, label_colname):

     concat_with_tf = linker._initialise_df_concat_with_tf()

-    sql = block_using_rules_sql(training_linker)
-    training_linker._enqueue_sql(sql, "__splink__df_blocked")
+    sqls = block_using_rules_sqls(training_linker)
+
+    for sql in sqls:
+        training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])

     sql = compute_comparison_vector_values_sql(settings_obj)

```
