From 1c75c36c0230cfca6ca8b2fc3c97d826b4460988 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Fri, 3 Nov 2023 09:38:56 +0000
Subject: [PATCH 1/3] refactor blocking sql

---
 splink/analyse_blocking.py    |  9 ++--
 splink/blocking.py            | 90 +++++++++++++++++++++--------------
 splink/em_training_session.py |  7 +--
 splink/estimate_u.py          |  7 +--
 splink/linker.py              | 28 ++++++-----
 splink/m_training.py          |  8 ++--
 6 files changed, 89 insertions(+), 60 deletions(-)

diff --git a/splink/analyse_blocking.py b/splink/analyse_blocking.py
index 201ec6e9c8..d4105123a5 100644
--- a/splink/analyse_blocking.py
+++ b/splink/analyse_blocking.py
@@ -5,7 +5,7 @@
 
 import pandas as pd
 
-from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sql
+from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sqls
 from .misc import calculate_cartesian, calculate_reduction_ratio
 
 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
@@ -83,8 +83,9 @@ def cumulative_comparisons_generated_by_blocking_rules(
         cartesian = calculate_cartesian(row_count_df, settings_obj._link_type)
 
     # Calculate the total number of rows generated by each blocking rule
-    sql = block_using_rules_sql(linker)
-    linker._enqueue_sql(sql, "__splink__df_blocked_data")
+    sqls = block_using_rules_sqls(linker)
+    for sql in sqls:
+        linker._enqueue_sql(sql["sql"], sql["output_table_name"])
 
     brs_as_objs = linker._settings_obj_._blocking_rules_to_generate_predictions
 
@@ -92,7 +93,7 @@
     select
     count(*) as row_count,
     match_key
-    from __splink__df_blocked_data
+    from __splink__df_blocked
     group by match_key
     order by cast(match_key as int) asc
     """
diff --git a/splink/blocking.py b/splink/blocking.py
index 18e697b624..7cd7fda6c2 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -197,7 +197,7 @@ def _sql_gen_where_condition(link_type, unique_id_cols):
 
 
 # flake8: noqa: C901
-def block_using_rules_sql(linker: Linker):
+def block_using_rules_sqls(linker: Linker):
     """Use the blocking rules specified in the linker's settings object to
     generate a SQL statement that will create pairwise record comparisons
     according to the blocking rule(s).
@@ -206,6 +206,54 @@
     so that duplicate comparisons are not generated.
     """
 
+    sqls = []
+
+    # For the two dataset link only, rather than a self join of
+    # __splink__df_concat_with_tf, it's much faster to split the input
+    # into two tables, and join (because then Splink doesn't have to evaluate
+    # intra-dataset comparisons).
+    # See https://github.com/moj-analytical-services/splink/pull/1359
+    if (
+        linker._two_dataset_link_only
+        and not linker._find_new_matches_mode
+        and not linker._compare_two_records_mode
+    ):
+        source_dataset_col = linker._source_dataset_column_name
+        # Need df_l to be the one with the lowest id to preserve the property
+        # that the left dataset is the one with the lowest concatenated id
+        keys = linker._input_tables_dict.keys()
+        keys = list(sorted(keys))
+        df_l = linker._input_tables_dict[keys[0]]
+        df_r = linker._input_tables_dict[keys[1]]
+
+        # This also needs to work for training u
+        if linker._train_u_using_random_sample_mode:
+            sample_switch = "_sample"
+        else:
+            sample_switch = ""
+
+        sql = f"""
+        select * from __splink__df_concat_with_tf{sample_switch}
+        where {source_dataset_col} = '{df_l.templated_name}'
+        """
+        sqls.append(
+            {
+                "sql": sql,
+                "output_table_name": f"__splink__df_concat_with_tf{sample_switch}_left",
+            }
+        )
+
+        sql = f"""
+        select * from __splink__df_concat_with_tf{sample_switch}
+        where {source_dataset_col} = '{df_r.templated_name}'
+        """
+        sqls.append(
+            {
+                "sql": sql,
+                "output_table_name": f"__splink__df_concat_with_tf{sample_switch}_right",
+            }
+        )
+
     if type(linker).__name__ in ["SparkLinker"]:
         apply_salt = True
     else:
@@ -243,36 +291,6 @@
             " will not be implemented for this run."
         )
 
-    if (
-        linker._two_dataset_link_only
-        and not linker._find_new_matches_mode
-        and not linker._compare_two_records_mode
-    ):
-        source_dataset_col = linker._source_dataset_column_name
-        # Need df_l to be the one with the lowest id to preeserve the property
-        # that the left dataset is the one with the lowest concatenated id
-        keys = linker._input_tables_dict.keys()
-        keys = list(sorted(keys))
-        df_l = linker._input_tables_dict[keys[0]]
-        df_r = linker._input_tables_dict[keys[1]]
-
-        if linker._train_u_using_random_sample_mode:
-            sample_switch = "_sample"
-        else:
-            sample_switch = ""
-
-        sql = f"""
-        select * from __splink__df_concat_with_tf{sample_switch}
-        where {source_dataset_col} = '{df_l.templated_name}'
-        """
-        linker._enqueue_sql(sql, f"__splink__df_concat_with_tf{sample_switch}_left")
-
-        sql = f"""
-        select * from __splink__df_concat_with_tf{sample_switch}
-        where {source_dataset_col} = '{df_r.templated_name}'
-        """
-        linker._enqueue_sql(sql, f"__splink__df_concat_with_tf{sample_switch}_right")
-
     # Cover the case where there are no blocking rules
     # This is a bit of a hack where if you do a self-join on 'true'
     # you create a cartesian product, rather than having separate code
@@ -287,7 +305,7 @@
     else:
         probability = ""
 
-    sqls = []
+    br_sqls = []
     for br in blocking_rules:
         # Apply our salted rules to resolve skew issues. If no salt was
         # selected to be added, then apply the initial blocking rule.
@@ -310,8 +328,10 @@ def block_using_rules_sql(linker: Linker): {where_condition} """ - sqls.append(sql) + br_sqls.append(sql) + + sql = "union all".join(br_sqls) - sql = "union all".join(sqls) + sqls.append({"sql": sql, "output_table_name": "__splink__df_blocked"}) - return sql + return sqls diff --git a/splink/em_training_session.py b/splink/em_training_session.py index 3f6df5244c..37397bcebf 100644 --- a/splink/em_training_session.py +++ b/splink/em_training_session.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import TYPE_CHECKING -from .blocking import BlockingRule, block_using_rules_sql +from .blocking import BlockingRule, block_using_rules_sqls from .charts import ( m_u_parameters_interactive_history_chart, match_weights_interactive_history_chart, @@ -151,8 +151,9 @@ def _comparison_vectors(self): nodes_with_tf = self._original_linker._initialise_df_concat_with_tf() - sql = block_using_rules_sql(self._training_linker) - self._training_linker._enqueue_sql(sql, "__splink__df_blocked") + sqls = block_using_rules_sqls(self) + for sql in sqls: + self._training_linker._enqueue_sql(sql["sql"], sql["output_table_name"]) # repartition after blocking only exists on the SparkLinker repartition_after_blocking = getattr( diff --git a/splink/estimate_u.py b/splink/estimate_u.py index 027414522e..4bfe667fdf 100644 --- a/splink/estimate_u.py +++ b/splink/estimate_u.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import TYPE_CHECKING -from .blocking import block_using_rules_sql +from .blocking import block_using_rules_sqls from .comparison_vector_values import compute_comparison_vector_values_sql from .expectation_maximisation import ( compute_new_parameters_sql, @@ -106,8 +106,9 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None): settings_obj._blocking_rules_to_generate_predictions = [] - sql = block_using_rules_sql(training_linker) - training_linker._enqueue_sql(sql, "__splink__df_blocked") + sqls = block_using_rules_sqls(training_linker) + for sql in sqls: + training_linker._enqueue_sql(sql["sql"], sql["output_table_name"]) # repartition after blocking only exists on the SparkLinker repartition_after_blocking = getattr( diff --git a/splink/linker.py b/splink/linker.py index 897dfc9899..c95d12a9dd 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -33,7 +33,7 @@ ) from .blocking import ( BlockingRule, - block_using_rules_sql, + block_using_rules_sqls, blocking_rule_to_obj, ) from .cache_dict_with_logging import CacheDictWithLogging @@ -1406,8 +1406,9 @@ def deterministic_link(self) -> SplinkDataFrame: self._deterministic_link_mode = True concat_with_tf = self._initialise_df_concat_with_tf() - sql = block_using_rules_sql(self) - self._enqueue_sql(sql, "__splink__df_blocked") + sqls = block_using_rules_sqls(self) + for sql in sqls: + self._enqueue_sql(sql["sql"], sql["output_table_name"]) return self._execute_sql_pipeline([concat_with_tf]) def estimate_u_using_random_sampling( @@ -1728,8 +1729,9 @@ def predict( if nodes_with_tf: input_dataframes.append(nodes_with_tf) - sql = block_using_rules_sql(self) - self._enqueue_sql(sql, "__splink__df_blocked") + sqls = block_using_rules_sqls(self) + for sql in sqls: + self._enqueue_sql(sql["sql"], sql["output_table_name"]) repartition_after_blocking = getattr(self, "repartition_after_blocking", False) @@ -1853,8 +1855,9 @@ def find_matches_to_new_records( add_unique_id_and_source_dataset_cols_if_needed(self, new_records_df) - sql = block_using_rules_sql(self) - self._enqueue_sql(sql, "__splink__df_blocked") + sqls = 
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])
 
         sql = compute_comparison_vector_values_sql(self._settings_obj)
         self._enqueue_sql(sql, "__splink__df_comparison_vectors")
@@ -1937,8 +1940,9 @@ def compare_two_records(self, record_1: dict, record_2: dict):
 
         self._enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf")
 
-        sql = block_using_rules_sql(self)
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])
 
         sql = compute_comparison_vector_values_sql(self._settings_obj)
         self._enqueue_sql(sql, "__splink__df_comparison_vectors")
@@ -1993,9 +1997,9 @@ def _self_link(self) -> SplinkDataFrame:
 
         nodes_with_tf = self._initialise_df_concat_with_tf()
 
-        sql = block_using_rules_sql(self)
-
-        self._enqueue_sql(sql, "__splink__df_blocked")
+        sqls = block_using_rules_sqls(self)
+        for sql in sqls:
+            self._enqueue_sql(sql["sql"], sql["output_table_name"])
 
         sql = compute_comparison_vector_values_sql(self._settings_obj)
 
diff --git a/splink/m_training.py b/splink/m_training.py
index 0b026f4b64..0fe7b13140 100644
--- a/splink/m_training.py
+++ b/splink/m_training.py
@@ -1,7 +1,7 @@
 import logging
 from copy import deepcopy
 
-from .blocking import BlockingRule, block_using_rules_sql
+from .blocking import BlockingRule, block_using_rules_sqls
 from .comparison_vector_values import compute_comparison_vector_values_sql
 from .expectation_maximisation import (
     compute_new_parameters_sql,
@@ -34,8 +34,10 @@ def estimate_m_values_from_label_column(linker, df_dict, label_colname):
 
     concat_with_tf = linker._initialise_df_concat_with_tf()
 
-    sql = block_using_rules_sql(training_linker)
-    training_linker._enqueue_sql(sql, "__splink__df_blocked")
+    sqls = block_using_rules_sqls(training_linker)
+
+    for sql in sqls:
+        training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])
 
     sql = compute_comparison_vector_values_sql(settings_obj)
 

From 0cca447f442c0e6c5f11617c1bb2b6c35f299a77 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Fri, 3 Nov 2023 09:43:02 +0000
Subject: [PATCH 2/3] line length

---
 splink/blocking.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 7cd7fda6c2..40ed9a1c43 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -228,29 +228,29 @@ def block_using_rules_sqls(linker: Linker):
 
         # This also needs to work for training u
         if linker._train_u_using_random_sample_mode:
-            sample_switch = "_sample"
+            spl_switch = "_sample"
         else:
-            sample_switch = ""
+            spl_switch = ""
 
         sql = f"""
-        select * from __splink__df_concat_with_tf{sample_switch}
+        select * from __splink__df_concat_with_tf{spl_switch}
         where {source_dataset_col} = '{df_l.templated_name}'
         """
         sqls.append(
             {
                 "sql": sql,
-                "output_table_name": f"__splink__df_concat_with_tf{sample_switch}_left",
+                "output_table_name": f"__splink__df_concat_with_tf{spl_switch}_left",
             }
         )
 
         sql = f"""
-        select * from __splink__df_concat_with_tf{sample_switch}
+        select * from __splink__df_concat_with_tf{spl_switch}
         where {source_dataset_col} = '{df_r.templated_name}'
         """
         sqls.append(
             {
                 "sql": sql,
-                "output_table_name": f"__splink__df_concat_with_tf{sample_switch}_right",
+                "output_table_name": f"__splink__df_concat_with_tf{spl_switch}_right",
             }
         )
 

From f36f498fcb076939589defa7527f561536a16c53 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Fri, 3 Nov 2023 10:21:02 +0000
Subject: [PATCH 3/3] fix to em_training_session.py

---
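Patch 1 accidentally called block_using_rules_sqls(self) from
EMTrainingSession._comparison_vectors, passing the training session itself
where a Linker is expected. Swap in the wrapped self._training_linker,
matching every other call site in the series.
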
splink/em_training_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/em_training_session.py b/splink/em_training_session.py index 37397bcebf..4ed9706e90 100644 --- a/splink/em_training_session.py +++ b/splink/em_training_session.py @@ -151,7 +151,7 @@ def _comparison_vectors(self): nodes_with_tf = self._original_linker._initialise_df_concat_with_tf() - sqls = block_using_rules_sqls(self) + sqls = block_using_rules_sqls(self._training_linker) for sql in sqls: self._training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])
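
After this series, block_using_rules_sqls returns a list of
{"sql": ..., "output_table_name": ...} dicts rather than a single SQL string,
and the final entry always materialises __splink__df_blocked. The sketch below
is a minimal illustration of that calling convention, not Splink's API:
DemoPipeline, enqueue_blocking_sqls and the sd column are hypothetical
stand-ins for the linker's SQL queue and the source dataset column.

from typing import Dict, List


class DemoPipeline:
    # Illustrative stand-in for the linker's SQL queue (not Splink's API).
    def __init__(self) -> None:
        self.queue: List[Dict[str, str]] = []

    def _enqueue_sql(self, sql: str, output_table_name: str) -> None:
        # Each enqueued statement later becomes a step named output_table_name.
        self.queue.append({"sql": sql, "output_table_name": output_table_name})


def enqueue_blocking_sqls(pipeline: DemoPipeline, sqls: List[Dict[str, str]]) -> None:
    # The same loop each call site uses in patch 1.
    for sql in sqls:
        pipeline._enqueue_sql(sql["sql"], sql["output_table_name"])


# Example return shape in the two-dataset link-only case: the input is
# split into left/right tables, then the blocked pairs are produced.
sqls = [
    {
        "sql": "select * from __splink__df_concat_with_tf where sd = 'a'",
        "output_table_name": "__splink__df_concat_with_tf_left",
    },
    {
        "sql": "select * from __splink__df_concat_with_tf where sd = 'b'",
        "output_table_name": "__splink__df_concat_with_tf_right",
    },
    {
        "sql": "select 1 as match_key",  # placeholder for the union over blocking rules
        "output_table_name": "__splink__df_blocked",
    },
]

pipeline = DemoPipeline()
enqueue_blocking_sqls(pipeline, sqls)
assert pipeline.queue[-1]["output_table_name"] == "__splink__df_blocked"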