diff --git a/splink/em_training_session.py b/splink/em_training_session.py index 8787ebc45e..2e1694877f 100644 --- a/splink/em_training_session.py +++ b/splink/em_training_session.py @@ -183,7 +183,11 @@ def _comparison_vectors(self): pipeline = CTEPipeline([nodes_with_tf]) sqls = block_using_rules_sqls( - self._original_linker, [self._blocking_rule_for_training] + self._original_linker, + input_tablename_l="__splink__df_concat_with_tf", + input_tablename_r="__splink__df_concat_with_tf", + blocking_rules=[self._blocking_rule_for_training], + link_type=self._original_linker._settings_obj._link_type, ) pipeline.enqueue_list_of_sqls(sqls) diff --git a/splink/linker.py b/splink/linker.py index ee4e8fe4d5..d8940735a7 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -246,8 +246,6 @@ def __init__( self._validate_settings(validate_settings) self._em_training_sessions: list[EMTrainingSession] = [] - self._find_new_matches_mode = False - self._self_link_mode = False self._deterministic_link_mode = False @@ -333,8 +331,6 @@ def _cache_uid(self, value): @property def _input_tablename_l(self): - if self._find_new_matches_mode: - return "__splink__df_concat_with_tf" if self._self_link_mode: return "__splink__df_concat_with_tf" @@ -352,8 +348,6 @@ def _input_tablename_l(self): @property def _input_tablename_r(self): - if self._find_new_matches_mode: - return "__splink__df_new_records_with_tf" if self._self_link_mode: return "__splink__df_concat_with_tf" @@ -1269,8 +1263,6 @@ def find_matches_to_new_records( self._settings_obj._blocking_rules_to_generate_predictions = blocking_rules - self._find_new_matches_mode = True - for tf_col in self._settings_obj._term_frequency_columns: tf_table = colname_to_tf_tablename(tf_col) if tf_table in self._intermediate_table_cache: @@ -1319,7 +1311,6 @@ def find_matches_to_new_records( original_blocking_rules ) self._settings_obj._link_type = original_link_type - self._find_new_matches_mode = False return predictions