diff --git a/docs/demos/tutorials/03_Blocking.ipynb b/docs/demos/tutorials/03_Blocking.ipynb index 929a13490a..83e86a4418 100644 --- a/docs/demos/tutorials/03_Blocking.ipynb +++ b/docs/demos/tutorials/03_Blocking.ipynb @@ -153,19 +153,19 @@ "\n", "blocking_rule_1 = block_on([\"substr(first_name, 1,1)\", \"surname\"])\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_1)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_1.sql}': {count:,.0f}\")\n", + "print(f\"Number of comparisons generated by '{blocking_rule_1.blocking_rule_sql}': {count:,.0f}\")\n", "\n", "blocking_rule_2 = block_on(\"surname\")\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_2)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_2.sql}': {count:,.0f}\")\n", + "print(f\"Number of comparisons generated by '{blocking_rule_2.blocking_rule_sql}': {count:,.0f}\")\n", "\n", "blocking_rule_3 = block_on(\"email\")\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_3)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_3.sql}': {count:,.0f}\")\n", + "print(f\"Number of comparisons generated by '{blocking_rule_3.blocking_rule_sql}': {count:,.0f}\")\n", "\n", "blocking_rule_4 = block_on([\"city\", \"first_name\"])\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_4)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_4.sql}': {count:,.0f}\")\n" + "print(f\"Number of comparisons generated by '{blocking_rule_4.blocking_rule_sql}': {count:,.0f}\")\n" ] }, { diff --git a/splink/accuracy.py b/splink/accuracy.py index d2b140ff08..d8a92a0491 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -1,4 +1,5 @@ from copy import deepcopy +from typing import TYPE_CHECKING from .block_from_labels import block_from_labels from .blocking import BlockingRule @@ -6,6 +7,9 @@ from .predict import predict_from_comparison_vectors_sqls from .sql_transform import 
move_l_r_table_prefix_to_column_suffix +if TYPE_CHECKING: + from .linker import Linker + def truth_space_table_from_labels_with_predictions_sqls( threshold_actual=0.5, match_weight_round_to_nearest=None @@ -143,10 +147,11 @@ def truth_space_table_from_labels_with_predictions_sqls( return sqls -def _select_found_by_blocking_rules(linker): +def _select_found_by_blocking_rules(linker: "Linker"): brs = linker._settings_obj._blocking_rules_to_generate_predictions + if brs: - brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule) for b in brs] + brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule_sql) for b in brs] brs = [f"(coalesce({b}, false))" for b in brs] brs = " OR ".join(brs) br_col = f" ({brs}) " diff --git a/splink/analyse_blocking.py b/splink/analyse_blocking.py index d4105123a5..d2f70ecbe3 100644 --- a/splink/analyse_blocking.py +++ b/splink/analyse_blocking.py @@ -118,7 +118,7 @@ def cumulative_comparisons_generated_by_blocking_rules( for row, br in zip(br_count, brs_as_objs): out_dict = { "row_count": row, - "rule": br.blocking_rule, + "rule": br.blocking_rule_sql, } if output_chart: cumulative_sum += row diff --git a/splink/blocking.py b/splink/blocking.py index 40ed9a1c43..303a4c6622 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -3,7 +3,7 @@ from sqlglot import parse_one from sqlglot.expressions import Join, Column from sqlglot.optimizer.eliminate_joins import join_condition -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, List import logging from .misc import ensure_is_list @@ -40,15 +40,20 @@ def blocking_rule_to_obj(br): class BlockingRule: def __init__( self, - blocking_rule: BlockingRule | dict | str, + blocking_rule_sql: str, salting_partitions=1, sqlglot_dialect: str = None, ): if sqlglot_dialect: self._sql_dialect = sqlglot_dialect - self.blocking_rule = blocking_rule - self.preceding_rules = [] + # Temporarily just to see if tests still pass + if not isinstance(blocking_rule_sql, str): + 
raise ValueError( + f"Blocking rule must be a string, not {type(blocking_rule_sql)}" + ) + self.blocking_rule_sql = blocking_rule_sql + self.preceding_rules: List[BlockingRule] = [] self.sqlglot_dialect = sqlglot_dialect self.salting_partitions = salting_partitions @@ -60,25 +65,30 @@ def sql_dialect(self): def match_key(self): return len(self.preceding_rules) - @property - def sql(self): - # Wrapper to reveal the underlying SQL - return self.blocking_rule - def add_preceding_rules(self, rules): rules = ensure_is_list(rules) self.preceding_rules = rules - @property - def and_not_preceding_rules_sql(self): - if not self.preceding_rules: - return "" + def exclude_pairs_generated_by_this_rule_sql(self): + """A SQL string specifying how to exclude the results + of THIS blocking rule from subsequent blocking statements, + so that subsequent statements do not produce duplicate pairs + """ # Note the coalesce function is important here - otherwise # you filter out any records with nulls in the previous rules # meaning these comparisons get lost + return f"coalesce(({self.blocking_rule_sql}),false)" + + @property + def exclude_pairs_generated_by_all_preceding_rules_sql(self): + """A SQL string that excludes the results of ALL previous blocking rules from + the pairwise comparisons generated.
+ """ + if not self.preceding_rules: + return "" or_clauses = [ - f"coalesce(({r.blocking_rule}), false)" for r in self.preceding_rules + br.exclude_pairs_generated_by_this_rule_sql() for br in self.preceding_rules ] previous_rules = " OR ".join(or_clauses) return f"AND NOT ({previous_rules})" @@ -86,14 +96,14 @@ def and_not_preceding_rules_sql(self): @property def salted_blocking_rules(self): if self.salting_partitions == 1: - yield self.blocking_rule + yield self.blocking_rule_sql else: for n in range(self.salting_partitions): - yield f"{self.blocking_rule} and ceiling(l.__splink_salt * {self.salting_partitions}) = {n+1}" # noqa: E501 + yield f"{self.blocking_rule_sql} and ceiling(l.__splink_salt * {self.salting_partitions}) = {n+1}" # noqa: E501 @property def _parsed_join_condition(self): - br = self.blocking_rule + br = self.blocking_rule_sql return parse_one("INNER JOIN r", into=Join).on( br, dialect=self.sqlglot_dialect ) # using sqlglot==11.4.1 @@ -147,7 +157,7 @@ def as_dict(self): "The minimal representation of the blocking rule" output = {} - output["blocking_rule"] = self.blocking_rule + output["blocking_rule"] = self.blocking_rule_sql output["sql_dialect"] = self.sql_dialect if self.salting_partitions > 1 and self.sql_dialect == "spark": @@ -157,7 +167,7 @@ def as_dict(self): def _as_completed_dict(self): if not self.salting_partitions > 1 and self.sql_dialect == "spark": - return self.blocking_rule + return self.blocking_rule_sql else: return self.as_dict() @@ -166,7 +176,7 @@ def descr(self): return "Custom" if not hasattr(self, "_description") else self._description def _abbreviated_sql(self, cutoff=75): - sql = self.blocking_rule + sql = self.blocking_rule_sql return (sql[:cutoff] + "...") if len(sql) > cutoff else sql def __repr__(self): @@ -312,7 +322,7 @@ def block_using_rules_sqls(linker: Linker): if apply_salt: salted_blocking_rules = br.salted_blocking_rules else: - salted_blocking_rules = [br.blocking_rule] + salted_blocking_rules = 
[br.blocking_rule_sql] for salted_br in salted_blocking_rules: sql = f""" @@ -324,7 +334,7 @@ def block_using_rules_sqls(linker: Linker): inner join {linker._input_tablename_r} as r on ({salted_br}) - {br.and_not_preceding_rules_sql} + {br.exclude_pairs_generated_by_all_preceding_rules_sql} {where_condition} """ diff --git a/splink/blocking_rule_composition.py b/splink/blocking_rule_composition.py index e6c4ce38f4..29cb01d9f6 100644 --- a/splink/blocking_rule_composition.py +++ b/splink/blocking_rule_composition.py @@ -295,7 +295,7 @@ def not_(*brls: BlockingRule | dict | str, salting_partitions: int = 1) -> Block brls, sql_dialect, salt = _parse_blocking_rules(*brls) br = brls[0] - blocking_rule = f"NOT ({br.blocking_rule})" + blocking_rule = f"NOT ({br.blocking_rule_sql})" return BlockingRule( blocking_rule, @@ -314,9 +314,9 @@ def _br_merge( brs, sql_dialect, salt = _parse_blocking_rules(*brls) if len(brs) > 1: - conditions = (f"({br.blocking_rule})" for br in brs) + conditions = (f"({br.blocking_rule_sql})" for br in brs) else: - conditions = (br.blocking_rule for br in brs) + conditions = (br.blocking_rule_sql for br in brs) blocking_rule = f" {clause} ".join(conditions) diff --git a/splink/em_training_session.py b/splink/em_training_session.py index 4ed9706e90..aaf01c986e 100644 --- a/splink/em_training_session.py +++ b/splink/em_training_session.py @@ -135,7 +135,7 @@ def _training_log_message(self): else: mu = "m and u probabilities" - blocking_rule = self._blocking_rule_for_training.blocking_rule + blocking_rule = self._blocking_rule_for_training.blocking_rule_sql logger.info( f"Estimating the {mu} of the model by blocking on:\n" @@ -176,7 +176,7 @@ def _train(self): # check that the blocking rule actually generates _some_ record pairs, # if not give the user a helpful message if not cvv.as_record_dict(limit=1): - br_sql = f"`{self._blocking_rule_for_training.blocking_rule}`" + br_sql = f"`{self._blocking_rule_for_training.blocking_rule_sql}`" raise 
EMTrainingException( f"Training rule {br_sql} resulted in no record pairs. " "This means that in the supplied data set " @@ -195,7 +195,7 @@ def _train(self): # in the original (main) setting object expectation_maximisation(self, cvv) - rule = self._blocking_rule_for_training.blocking_rule + rule = self._blocking_rule_for_training.blocking_rule_sql training_desc = f"EM, blocked on: {rule}" # Add m and u values to original settings @@ -254,7 +254,7 @@ def _blocking_adjusted_probability_two_random_records_match(self): comp_levels = self._comparison_levels_to_reverse_blocking_rule if not comp_levels: comp_levels = self._original_settings_obj._get_comparison_levels_corresponding_to_training_blocking_rule( # noqa - self._blocking_rule_for_training.blocking_rule + self._blocking_rule_for_training.blocking_rule_sql ) for cl in comp_levels: @@ -271,7 +271,7 @@ def _blocking_adjusted_probability_two_random_records_match(self): logger.log( 15, f"\nProb two random records match adjusted for blocking on " - f"{self._blocking_rule_for_training.blocking_rule}: " + f"{self._blocking_rule_for_training.blocking_rule_sql}: " f"{adjusted_prop_m:.3f}", ) return adjusted_prop_m @@ -411,7 +411,7 @@ def __repr__(self): for cc in self._comparisons_that_cannot_be_estimated ] ) - blocking_rule = self._blocking_rule_for_training.blocking_rule + blocking_rule = self._blocking_rule_for_training.blocking_rule_sql return ( f"" diff --git a/splink/linker.py b/splink/linker.py index 2c4e1080c9..c4a96b4a87 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -994,7 +994,7 @@ def _populate_probability_two_random_records_match_from_trained_values(self): 15, "\n" f"Probability two random records match from trained model blocking on " - f"{em_training_session._blocking_rule_for_training.blocking_rule}: " + f"{em_training_session._blocking_rule_for_training.blocking_rule_sql}: " f"{training_lambda:,.3f}", ) @@ -1630,7 +1630,7 @@ def estimate_parameters_using_expectation_maximisation( 
self._initialise_df_concat_with_tf() # Extract the blocking rule - blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule + blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql if comparisons_to_deactivate: # If user provided a string, convert to Comparison object @@ -3100,7 +3100,7 @@ def count_num_comparisons_from_blocking_rule( int: The number of comparisons generated by the blocking rule """ - blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule + blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql sql = vertically_concatenate_sql(self) self._enqueue_sql(sql, "__splink__df_concat") diff --git a/splink/settings.py b/splink/settings.py index 90cc165dfc..f14b2d79c1 100644 --- a/splink/settings.py +++ b/splink/settings.py @@ -2,8 +2,9 @@ import logging from copy import deepcopy +from typing import List -from .blocking import blocking_rule_to_obj +from .blocking import BlockingRule, blocking_rule_to_obj from .charts import m_u_parameters_chart, match_weights_chart from .comparison import Comparison from .comparison_level import ComparisonLevel @@ -125,7 +126,7 @@ def _get_additional_columns_to_retain(self): used_by_brs = [] for br in self._blocking_rules_to_generate_predictions: used_by_brs.extend( - get_columns_used_from_sql(br.blocking_rule, br.sql_dialect) + get_columns_used_from_sql(br.blocking_rule_sql, br.sql_dialect) ) used_by_brs = [InputColumn(c) for c in used_by_brs] @@ -300,7 +301,7 @@ def _get_comparison_by_output_column_name(self, name): return cc raise ValueError(f"No comparison column with name {name}") - def _brs_as_objs(self, brs_as_strings): + def _brs_as_objs(self, brs_as_strings) -> List[BlockingRule]: brs_as_objs = [blocking_rule_to_obj(br) for br in brs_as_strings] for n, br in enumerate(brs_as_objs): br.add_preceding_rules(brs_as_objs[:n]) diff --git a/splink/settings_validation/settings_validator.py b/splink/settings_validation/settings_validator.py index 06e24acfb9..ca7fe754c0 100644 
--- a/splink/settings_validation/settings_validator.py +++ b/splink/settings_validation/settings_validator.py @@ -4,6 +4,7 @@ import re from functools import reduce from operator import and_ +from typing import List import sqlglot @@ -49,9 +50,9 @@ def uid(self): return self.clean_list_of_column_names(uid_as_tree) @property - def blocking_rules(self): + def blocking_rules(self) -> List[str]: brs = self.settings_obj._blocking_rules_to_generate_predictions - return [br.blocking_rule for br in brs] + return [br.blocking_rule_sql for br in brs] @property def comparisons(self): diff --git a/tests/test_blocking.py b/tests/test_blocking.py index 0eb2113cc5..e84297ddc0 100644 --- a/tests/test_blocking.py +++ b/tests/test_blocking.py @@ -19,7 +19,7 @@ def test_binary_composition_internals_OR(test_helpers, dialect): assert br_surname.__repr__() == exp_txt.format("Exact match", em_rule) assert BlockingRule(em_rule).__repr__() == exp_txt.format("Custom", em_rule) - assert br_surname.blocking_rule == em_rule + assert br_surname.blocking_rule_sql == em_rule assert br_surname.salting_partitions == 4 assert br_surname.preceding_rules == [] @@ -40,13 +40,13 @@ def test_binary_composition_internals_OR(test_helpers, dialect): brl.exact_match_rule("help4"), ] brs_as_objs = settings_tester._brs_as_objs(brs_as_strings) - brs_as_txt = [blocking_rule_to_obj(br).blocking_rule for br in brs_as_strings] + brs_as_txt = [blocking_rule_to_obj(br).blocking_rule_sql for br in brs_as_strings] assert brs_as_objs[0].preceding_rules == [] def assess_preceding_rules(settings_brs_index): br_prec = brs_as_objs[settings_brs_index].preceding_rules - br_prec_txt = [br.blocking_rule for br in br_prec] + br_prec_txt = [br.blocking_rule_sql for br in br_prec] assert br_prec_txt == brs_as_txt[:settings_brs_index] assess_preceding_rules(1) diff --git a/tests/test_blocking_rule_composition.py b/tests/test_blocking_rule_composition.py index fadc9d26bd..3e02c89d5b 100644 --- 
a/tests/test_blocking_rule_composition.py +++ b/tests/test_blocking_rule_composition.py @@ -11,7 +11,7 @@ def binary_composition_internals(clause, comp_fun, brl, dialect): # Test what happens when only one value is fed # It should just report the regular outputs of our comparison level func level = comp_fun(brl.exact_match_rule("tom")) - assert level.blocking_rule == f"l.{q}tom{q} = r.{q}tom{q}" + assert level.blocking_rule_sql == f"l.{q}tom{q} = r.{q}tom{q}" # Exact match and null level composition level = comp_fun( @@ -19,12 +19,12 @@ def binary_composition_internals(clause, comp_fun, brl, dialect): brl.exact_match_rule("surname"), ) exact_match_sql = f"(l.{q}first_name{q} = r.{q}first_name{q}) {clause} (l.{q}surname{q} = r.{q}surname{q})" # noqa: E501 - assert level.blocking_rule == exact_match_sql + assert level.blocking_rule_sql == exact_match_sql # brl.not_(or_(...)) composition level = brl.not_( comp_fun(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")), ) - assert level.blocking_rule == f"NOT ({exact_match_sql})" + assert level.blocking_rule_sql == f"NOT ({exact_match_sql})" # Check salting outputs # salting included in the composition function diff --git a/tests/test_u_train.py b/tests/test_u_train.py index 453e13beae..274701ebcd 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -39,7 +39,7 @@ def test_u_train(test_helpers, dialect): assert cl_no.u_probability == (denom - 2) / denom br = linker._settings_obj._blocking_rules_to_generate_predictions[0] - assert br.blocking_rule == "l.name = r.name" + assert br.blocking_rule_sql == "l.name = r.name" @mark_with_dialects_excluding()