From 5a7188e9fd00071fa161e30ddc669fca0f92ffa0 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 16:45:16 +0000 Subject: [PATCH 01/11] rename sql property --- splink/analyse_blocking.py | 2 +- splink/blocking.py | 23 ++++++++----------- splink/blocking_rule_composition.py | 6 ++--- splink/em_training_session.py | 12 +++++----- splink/linker.py | 4 ++-- splink/settings.py | 2 +- .../settings_validation/settings_validator.py | 2 +- tests/test_blocking.py | 2 +- 8 files changed, 24 insertions(+), 29 deletions(-) diff --git a/splink/analyse_blocking.py b/splink/analyse_blocking.py index d4105123a5..d2f70ecbe3 100644 --- a/splink/analyse_blocking.py +++ b/splink/analyse_blocking.py @@ -118,7 +118,7 @@ def cumulative_comparisons_generated_by_blocking_rules( for row, br in zip(br_count, brs_as_objs): out_dict = { "row_count": row, - "rule": br.blocking_rule, + "rule": br.blocking_rule_sql, } if output_chart: cumulative_sum += row diff --git a/splink/blocking.py b/splink/blocking.py index 40ed9a1c43..acd4369d24 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -40,14 +40,14 @@ def blocking_rule_to_obj(br): class BlockingRule: def __init__( self, - blocking_rule: BlockingRule | dict | str, + blocking_rule_sql: BlockingRule | dict | str, salting_partitions=1, sqlglot_dialect: str = None, ): if sqlglot_dialect: self._sql_dialect = sqlglot_dialect - self.blocking_rule = blocking_rule + self.blocking_rule_sql = blocking_rule_sql self.preceding_rules = [] self.sqlglot_dialect = sqlglot_dialect self.salting_partitions = salting_partitions @@ -60,11 +60,6 @@ def sql_dialect(self): def match_key(self): return len(self.preceding_rules) - @property - def sql(self): - # Wrapper to reveal the underlying SQL - return self.blocking_rule - def add_preceding_rules(self, rules): rules = ensure_is_list(rules) self.preceding_rules = rules @@ -86,14 +81,14 @@ def and_not_preceding_rules_sql(self): @property def salted_blocking_rules(self): if self.salting_partitions == 1: - yield self.blocking_rule + yield self.blocking_rule_sql else: for n in range(self.salting_partitions): - yield f"{self.blocking_rule} and ceiling(l.__splink_salt * {self.salting_partitions}) = {n+1}" # noqa: E501 + yield f"{self.blocking_rule_sql} and ceiling(l.__splink_salt * {self.salting_partitions}) = {n+1}" # noqa: E501 @property def _parsed_join_condition(self): - br = self.blocking_rule + br = self.blocking_rule_sql return parse_one("INNER JOIN r", into=Join).on( br, dialect=self.sqlglot_dialect ) # using sqlglot==11.4.1 @@ -147,7 +142,7 @@ def as_dict(self): "The minimal representation of the blocking rule" output = {} - output["blocking_rule"] = self.blocking_rule + output["blocking_rule"] = self.blocking_rule_sql output["sql_dialect"] = self.sql_dialect if self.salting_partitions > 1 and self.sql_dialect == "spark": @@ -157,7 +152,7 @@ def as_dict(self): def _as_completed_dict(self): if not self.salting_partitions > 1 and self.sql_dialect == "spark": - return self.blocking_rule + return self.blocking_rule_sql else: return self.as_dict() @@ -166,7 +161,7 @@ def descr(self): return "Custom" if not hasattr(self, "_description") else self._description def _abbreviated_sql(self, cutoff=75): - sql = self.blocking_rule + sql = self.blocking_rule_sql return (sql[:cutoff] + "...") if len(sql) > cutoff else sql def __repr__(self): @@ -312,7 +307,7 @@ def block_using_rules_sqls(linker: Linker): if apply_salt: salted_blocking_rules = br.salted_blocking_rules else: - salted_blocking_rules = [br.blocking_rule] + 
salted_blocking_rules = [br.blocking_rule_sql] for salted_br in salted_blocking_rules: sql = f""" diff --git a/splink/blocking_rule_composition.py b/splink/blocking_rule_composition.py index e6c4ce38f4..29cb01d9f6 100644 --- a/splink/blocking_rule_composition.py +++ b/splink/blocking_rule_composition.py @@ -295,7 +295,7 @@ def not_(*brls: BlockingRule | dict | str, salting_partitions: int = 1) -> Block brls, sql_dialect, salt = _parse_blocking_rules(*brls) br = brls[0] - blocking_rule = f"NOT ({br.blocking_rule})" + blocking_rule = f"NOT ({br.blocking_rule_sql})" return BlockingRule( blocking_rule, @@ -314,9 +314,9 @@ def _br_merge( brs, sql_dialect, salt = _parse_blocking_rules(*brls) if len(brs) > 1: - conditions = (f"({br.blocking_rule})" for br in brs) + conditions = (f"({br.blocking_rule_sql})" for br in brs) else: - conditions = (br.blocking_rule for br in brs) + conditions = (br.blocking_rule_sql for br in brs) blocking_rule = f" {clause} ".join(conditions) diff --git a/splink/em_training_session.py b/splink/em_training_session.py index 4ed9706e90..aaf01c986e 100644 --- a/splink/em_training_session.py +++ b/splink/em_training_session.py @@ -135,7 +135,7 @@ def _training_log_message(self): else: mu = "m and u probabilities" - blocking_rule = self._blocking_rule_for_training.blocking_rule + blocking_rule = self._blocking_rule_for_training.blocking_rule_sql logger.info( f"Estimating the {mu} of the model by blocking on:\n" @@ -176,7 +176,7 @@ def _train(self): # check that the blocking rule actually generates _some_ record pairs, # if not give the user a helpful message if not cvv.as_record_dict(limit=1): - br_sql = f"`{self._blocking_rule_for_training.blocking_rule}`" + br_sql = f"`{self._blocking_rule_for_training.blocking_rule_sql}`" raise EMTrainingException( f"Training rule {br_sql} resulted in no record pairs. 
" "This means that in the supplied data set " @@ -195,7 +195,7 @@ def _train(self): # in the original (main) setting object expectation_maximisation(self, cvv) - rule = self._blocking_rule_for_training.blocking_rule + rule = self._blocking_rule_for_training.blocking_rule_sql training_desc = f"EM, blocked on: {rule}" # Add m and u values to original settings @@ -254,7 +254,7 @@ def _blocking_adjusted_probability_two_random_records_match(self): comp_levels = self._comparison_levels_to_reverse_blocking_rule if not comp_levels: comp_levels = self._original_settings_obj._get_comparison_levels_corresponding_to_training_blocking_rule( # noqa - self._blocking_rule_for_training.blocking_rule + self._blocking_rule_for_training.blocking_rule_sql ) for cl in comp_levels: @@ -271,7 +271,7 @@ def _blocking_adjusted_probability_two_random_records_match(self): logger.log( 15, f"\nProb two random records match adjusted for blocking on " - f"{self._blocking_rule_for_training.blocking_rule}: " + f"{self._blocking_rule_for_training.blocking_rule_sql}: " f"{adjusted_prop_m:.3f}", ) return adjusted_prop_m @@ -411,7 +411,7 @@ def __repr__(self): for cc in self._comparisons_that_cannot_be_estimated ] ) - blocking_rule = self._blocking_rule_for_training.blocking_rule + blocking_rule = self._blocking_rule_for_training.blocking_rule_sql return ( f"" diff --git a/splink/linker.py b/splink/linker.py index 7b520efd7b..6e20bbd5b5 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1630,7 +1630,7 @@ def estimate_parameters_using_expectation_maximisation( self._initialise_df_concat_with_tf() # Extract the blocking rule - blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule + blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql if comparisons_to_deactivate: # If user provided a string, convert to Comparison object @@ -3100,7 +3100,7 @@ def count_num_comparisons_from_blocking_rule( int: The number of comparisons generated by the blocking rule """ - blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule + blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql sql = vertically_concatenate_sql(self) self._enqueue_sql(sql, "__splink__df_concat") diff --git a/splink/settings.py b/splink/settings.py index 90cc165dfc..261eb9e886 100644 --- a/splink/settings.py +++ b/splink/settings.py @@ -125,7 +125,7 @@ def _get_additional_columns_to_retain(self): used_by_brs = [] for br in self._blocking_rules_to_generate_predictions: used_by_brs.extend( - get_columns_used_from_sql(br.blocking_rule, br.sql_dialect) + get_columns_used_from_sql(br.blocking_rule_sql, br.sql_dialect) ) used_by_brs = [InputColumn(c) for c in used_by_brs] diff --git a/splink/settings_validation/settings_validator.py b/splink/settings_validation/settings_validator.py index 06e24acfb9..a4b84743f8 100644 --- a/splink/settings_validation/settings_validator.py +++ b/splink/settings_validation/settings_validator.py @@ -51,7 +51,7 @@ def uid(self): @property def blocking_rules(self): brs = self.settings_obj._blocking_rules_to_generate_predictions - return [br.blocking_rule for br in brs] + return [br.blocking_rule_sql for br in brs] @property def comparisons(self): diff --git a/tests/test_blocking.py b/tests/test_blocking.py index 0eb2113cc5..fd01645275 100644 --- a/tests/test_blocking.py +++ b/tests/test_blocking.py @@ -40,7 +40,7 @@ def test_binary_composition_internals_OR(test_helpers, dialect): brl.exact_match_rule("help4"), ] brs_as_objs = settings_tester._brs_as_objs(brs_as_strings) - brs_as_txt = 
[blocking_rule_to_obj(br).blocking_rule for br in brs_as_strings] + brs_as_txt = [blocking_rule_to_obj(br).blocking_rule_sql for br in brs_as_strings] assert brs_as_objs[0].preceding_rules == [] From 64556ae3ea3752b9ad8a9ad3401deda4a6183840 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 16:51:26 +0000 Subject: [PATCH 02/11] type hinting --- splink/accuracy.py | 6 ++++-- splink/settings.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/splink/accuracy.py b/splink/accuracy.py index d2b140ff08..5bd8c75c4e 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -3,6 +3,7 @@ from .block_from_labels import block_from_labels from .blocking import BlockingRule from .comparison_vector_values import compute_comparison_vector_values_sql +from .linker import Linker from .predict import predict_from_comparison_vectors_sqls from .sql_transform import move_l_r_table_prefix_to_column_suffix @@ -143,10 +144,11 @@ def truth_space_table_from_labels_with_predictions_sqls( return sqls -def _select_found_by_blocking_rules(linker): +def _select_found_by_blocking_rules(linker: Linker): brs = linker._settings_obj._blocking_rules_to_generate_predictions + if brs: - brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule) for b in brs] + brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule_sql) for b in brs] brs = [f"(coalesce({b}, false))" for b in brs] brs = " OR ".join(brs) br_col = f" ({brs}) " diff --git a/splink/settings.py b/splink/settings.py index 261eb9e886..f14b2d79c1 100644 --- a/splink/settings.py +++ b/splink/settings.py @@ -2,8 +2,9 @@ import logging from copy import deepcopy +from typing import List -from .blocking import blocking_rule_to_obj +from .blocking import BlockingRule, blocking_rule_to_obj from .charts import m_u_parameters_chart, match_weights_chart from .comparison import Comparison from .comparison_level import ComparisonLevel @@ -300,7 +301,7 @@ def _get_comparison_by_output_column_name(self, name): return cc raise ValueError(f"No comparison column with name {name}") - def _brs_as_objs(self, brs_as_strings): + def _brs_as_objs(self, brs_as_strings) -> List[BlockingRule]: brs_as_objs = [blocking_rule_to_obj(br) for br in brs_as_strings] for n, br in enumerate(brs_as_objs): br.add_preceding_rules(brs_as_objs[:n]) From 189e1e5aba3cd59791cf9f9d7501378be049ed5b Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 16:55:36 +0000 Subject: [PATCH 03/11] fixes and typing --- splink/accuracy.py | 9 ++++++--- splink/blocking.py | 6 +++--- splink/settings_validation/settings_validator.py | 3 ++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/splink/accuracy.py b/splink/accuracy.py index 5bd8c75c4e..4512b20492 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -1,12 +1,15 @@ from copy import deepcopy - +from typing import TYPE_CHECKING from .block_from_labels import block_from_labels from .blocking import BlockingRule from .comparison_vector_values import compute_comparison_vector_values_sql -from .linker import Linker + from .predict import predict_from_comparison_vectors_sqls from .sql_transform import move_l_r_table_prefix_to_column_suffix +if TYPE_CHECKING: + from .linker import Linker + def truth_space_table_from_labels_with_predictions_sqls( threshold_actual=0.5, match_weight_round_to_nearest=None @@ -144,7 +147,7 @@ def truth_space_table_from_labels_with_predictions_sqls( return sqls -def _select_found_by_blocking_rules(linker: Linker): +def _select_found_by_blocking_rules(linker: "Linker"): 
brs = linker._settings_obj._blocking_rules_to_generate_predictions if brs: diff --git a/splink/blocking.py b/splink/blocking.py index acd4369d24..80f4e3d082 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -3,7 +3,7 @@ from sqlglot import parse_one from sqlglot.expressions import Join, Column from sqlglot.optimizer.eliminate_joins import join_condition -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, List import logging from .misc import ensure_is_list @@ -48,7 +48,7 @@ def __init__( self._sql_dialect = sqlglot_dialect self.blocking_rule_sql = blocking_rule_sql - self.preceding_rules = [] + self.preceding_rules: List[BlockingRule] = [] self.sqlglot_dialect = sqlglot_dialect self.salting_partitions = salting_partitions @@ -73,7 +73,7 @@ def and_not_preceding_rules_sql(self): # you filter out any records with nulls in the previous rules # meaning these comparisons get lost or_clauses = [ - f"coalesce(({r.blocking_rule}), false)" for r in self.preceding_rules + f"coalesce(({r.blocking_rule_sql}), false)" for r in self.preceding_rules ] previous_rules = " OR ".join(or_clauses) return f"AND NOT ({previous_rules})" diff --git a/splink/settings_validation/settings_validator.py b/splink/settings_validation/settings_validator.py index a4b84743f8..ca7fe754c0 100644 --- a/splink/settings_validation/settings_validator.py +++ b/splink/settings_validation/settings_validator.py @@ -4,6 +4,7 @@ import re from functools import reduce from operator import and_ +from typing import List import sqlglot @@ -49,7 +50,7 @@ def uid(self): return self.clean_list_of_column_names(uid_as_tree) @property - def blocking_rules(self): + def blocking_rules(self) -> List[str]: brs = self.settings_obj._blocking_rules_to_generate_predictions return [br.blocking_rule_sql for br in brs] From fed4a553c4b5f30346cf9b752f81b328c5fbe5f8 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 16:56:55 +0000 Subject: [PATCH 04/11] lint --- splink/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/accuracy.py b/splink/accuracy.py index 4512b20492..d8a92a0491 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -1,9 +1,9 @@ from copy import deepcopy from typing import TYPE_CHECKING + from .block_from_labels import block_from_labels from .blocking import BlockingRule from .comparison_vector_values import compute_comparison_vector_values_sql - from .predict import predict_from_comparison_vectors_sqls from .sql_transform import move_l_r_table_prefix_to_column_suffix From 087987a8c4678412750e1d21dd6e34a8d2ab10d7 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 17:03:30 +0000 Subject: [PATCH 05/11] improve property and method names --- splink/blocking.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/splink/blocking.py b/splink/blocking.py index 80f4e3d082..6a6928982e 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -64,16 +64,26 @@ def add_preceding_rules(self, rules): rules = ensure_is_list(rules) self.preceding_rules = rules - @property - def and_not_preceding_rules_sql(self): - if not self.preceding_rules: - return "" + def exclude_pairs_generated_by_this_rule_sql(self, linker: Linker): + """A SQL string specifying how to exclude the results + of THIS blocking rule from subseqent blocking statements, + so that subsequent statements do not produce duplicate pairs + """ # Note the coalesce function is important here - otherwise # you filter out any records with nulls in the 
previous rules # meaning these comparisons get lost + return f"coalesce(({self.blocking_rule_sql}),false)" + + @property + def exclude_pairs_generated_by_all_preceding_rules_sql(self): + """A SQL string that excludes the results of ALL previous blocking rules from + the pairwise comparisons generated. + """ + if not self.preceding_rules: + return "" or_clauses = [ - f"coalesce(({r.blocking_rule_sql}), false)" for r in self.preceding_rules + br.exclude_pairs_generated_by_this_rule_sql() for br in self.preceding_rules ] previous_rules = " OR ".join(or_clauses) return f"AND NOT ({previous_rules})" @@ -319,7 +329,7 @@ def block_using_rules_sqls(linker: Linker): inner join {linker._input_tablename_r} as r on ({salted_br}) - {br.and_not_preceding_rules_sql} + {br.exclude_pairs_generated_by_all_preceding_rules_sql} {where_condition} """ From e1e84952d108c2dc190468e8bc527841b6831b81 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 17:07:17 +0000 Subject: [PATCH 06/11] fix bug --- splink/blocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/blocking.py b/splink/blocking.py index 6a6928982e..c8dd68aed3 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -64,7 +64,7 @@ def add_preceding_rules(self, rules): rules = ensure_is_list(rules) self.preceding_rules = rules - def exclude_pairs_generated_by_this_rule_sql(self, linker: Linker): + def exclude_pairs_generated_by_this_rule_sql(self): """A SQL string specifying how to exclude the results of THIS blocking rule from subseqent blocking statements, so that subsequent statements do not produce duplicate pairs From e6404f1f109e40d5c449dafea71d2718f72d454e Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 17:13:35 +0000 Subject: [PATCH 07/11] fix ipynb --- docs/demos/tutorials/03_Blocking.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/demos/tutorials/03_Blocking.ipynb b/docs/demos/tutorials/03_Blocking.ipynb index 929a13490a..83e86a4418 100644 --- a/docs/demos/tutorials/03_Blocking.ipynb +++ b/docs/demos/tutorials/03_Blocking.ipynb @@ -153,19 +153,19 @@ "\n", "blocking_rule_1 = block_on([\"substr(first_name, 1,1)\", \"surname\"])\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_1)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_1.sql}': {count:,.0f}\")\n", + "print(f\"Number of comparisons generated by '{blocking_rule_1.blocking_rule_sql}': {count:,.0f}\")\n", "\n", "blocking_rule_2 = block_on(\"surname\")\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_2)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_2.sql}': {count:,.0f}\")\n", + "print(f\"Number of comparisons generated by '{blocking_rule_2.blocking_rule_sql}': {count:,.0f}\")\n", "\n", "blocking_rule_3 = block_on(\"email\")\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_3)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_3.sql}': {count:,.0f}\")\n", + "print(f\"Number of comparisons generated by '{blocking_rule_3.blocking_rule_sql}': {count:,.0f}\")\n", "\n", "blocking_rule_4 = block_on([\"city\", \"first_name\"])\n", "count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_4)\n", - "print(f\"Number of comparisons generated by '{blocking_rule_4.sql}': {count:,.0f}\")\n" + "print(f\"Number of comparisons generated by '{blocking_rule_4.blocking_rule_sql}': {count:,.0f}\")\n" ] }, { From b488167cd895fe3e1c47357a74fa5c2204527dbd Mon Sep 17 
00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 17:24:53 +0000 Subject: [PATCH 08/11] fix tests --- tests/test_blocking.py | 4 ++-- tests/test_blocking_rule_composition.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_blocking.py b/tests/test_blocking.py index fd01645275..e84297ddc0 100644 --- a/tests/test_blocking.py +++ b/tests/test_blocking.py @@ -19,7 +19,7 @@ def test_binary_composition_internals_OR(test_helpers, dialect): assert br_surname.__repr__() == exp_txt.format("Exact match", em_rule) assert BlockingRule(em_rule).__repr__() == exp_txt.format("Custom", em_rule) - assert br_surname.blocking_rule == em_rule + assert br_surname.blocking_rule_sql == em_rule assert br_surname.salting_partitions == 4 assert br_surname.preceding_rules == [] @@ -46,7 +46,7 @@ def test_binary_composition_internals_OR(test_helpers, dialect): def assess_preceding_rules(settings_brs_index): br_prec = brs_as_objs[settings_brs_index].preceding_rules - br_prec_txt = [br.blocking_rule for br in br_prec] + br_prec_txt = [br.blocking_rule_sql for br in br_prec] assert br_prec_txt == brs_as_txt[:settings_brs_index] assess_preceding_rules(1) diff --git a/tests/test_blocking_rule_composition.py b/tests/test_blocking_rule_composition.py index fadc9d26bd..3e02c89d5b 100644 --- a/tests/test_blocking_rule_composition.py +++ b/tests/test_blocking_rule_composition.py @@ -11,7 +11,7 @@ def binary_composition_internals(clause, comp_fun, brl, dialect): # Test what happens when only one value is fed # It should just report the regular outputs of our comparison level func level = comp_fun(brl.exact_match_rule("tom")) - assert level.blocking_rule == f"l.{q}tom{q} = r.{q}tom{q}" + assert level.blocking_rule_sql == f"l.{q}tom{q} = r.{q}tom{q}" # Exact match and null level composition level = comp_fun( @@ -19,12 +19,12 @@ def binary_composition_internals(clause, comp_fun, brl, dialect): brl.exact_match_rule("surname"), ) exact_match_sql = f"(l.{q}first_name{q} = r.{q}first_name{q}) {clause} (l.{q}surname{q} = r.{q}surname{q})" # noqa: E501 - assert level.blocking_rule == exact_match_sql + assert level.blocking_rule_sql == exact_match_sql # brl.not_(or_(...)) composition level = brl.not_( comp_fun(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")), ) - assert level.blocking_rule == f"NOT ({exact_match_sql})" + assert level.blocking_rule_sql == f"NOT ({exact_match_sql})" # Check salting outputs # salting included in the composition function From fe17c8ed6baf9d2765974ede7874b100207bdd3f Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 17:33:52 +0000 Subject: [PATCH 09/11] fix tests --- tests/test_u_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_u_train.py b/tests/test_u_train.py index 453e13beae..274701ebcd 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -39,7 +39,7 @@ def test_u_train(test_helpers, dialect): assert cl_no.u_probability == (denom - 2) / denom br = linker._settings_obj._blocking_rules_to_generate_predictions[0] - assert br.blocking_rule == "l.name = r.name" + assert br.blocking_rule_sql == "l.name = r.name" @mark_with_dialects_excluding() From 82c0d4073eb6af4afc574d34a4ef4c06a1d5120a Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 6 Nov 2023 17:43:41 +0000 Subject: [PATCH 10/11] fix tests --- splink/linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/linker.py b/splink/linker.py index 6e20bbd5b5..f8825bc7e0 100644 --- a/splink/linker.py +++ 
b/splink/linker.py
@@ -994,7 +994,7 @@ def _populate_probability_two_random_records_match_from_trained_values(self):
                 15,
                 "\n"
                 f"Probability two random records match from trained model blocking on "
-                f"{em_training_session._blocking_rule_for_training.blocking_rule}: "
+                f"{em_training_session._blocking_rule_for_training.blocking_rule_sql}: "
                 f"{training_lambda:,.3f}",
             )
 

From c83ea66b96c2bb9bc85bceeadf3190f04f89b96f Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 7 Nov 2023 15:07:07 +0000
Subject: [PATCH 11/11] check blocking_rule_sql is a string

---
 splink/blocking.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index c8dd68aed3..303a4c6622 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -40,13 +40,18 @@ def blocking_rule_to_obj(br):
 class BlockingRule:
     def __init__(
         self,
-        blocking_rule_sql: BlockingRule | dict | str,
+        blocking_rule_sql: str,
         salting_partitions=1,
         sqlglot_dialect: str = None,
     ):
         if sqlglot_dialect:
             self._sql_dialect = sqlglot_dialect
 
+        # Temporarily just to see if tests still pass
+        if not isinstance(blocking_rule_sql, str):
+            raise ValueError(
+                f"Blocking rule must be a string, not {type(blocking_rule_sql)}"
+            )
         self.blocking_rule_sql = blocking_rule_sql
         self.preceding_rules: List[BlockingRule] = []
         self.sqlglot_dialect = sqlglot_dialect
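
Usage note (not part of the commits above): after this series, the SQL string for a blocking rule is read from the blocking_rule_sql attribute rather than blocking_rule, and the BlockingRule constructor rejects non-string inputs. The following is a minimal sketch of the renamed API, assuming splink is installed at the state of this branch; the sample rule and values are illustrative only.

    from splink.blocking import BlockingRule, blocking_rule_to_obj

    # blocking_rule_to_obj converts a plain SQL string into a BlockingRule;
    # the raw SQL is now exposed via .blocking_rule_sql (previously .blocking_rule)
    br = blocking_rule_to_obj("l.surname = r.surname")
    print(br.blocking_rule_sql)  # -> l.surname = r.surname

    # Patch 11 adds a guard: passing anything other than a string straight to
    # the constructor now fails fast with a ValueError
    try:
        BlockingRule({"blocking_rule": "l.surname = r.surname"})
    except ValueError as err:
        print(err)  # Blocking rule must be a string, not <class 'dict'>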