From 511a644648697840b5147d26e641cc303d394200 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Fri, 8 Nov 2024 13:53:13 +0000 Subject: [PATCH] works with tf columns --- .../internals/linker_components/inference.py | 30 +++++++++++++--- splink/internals/term_frequencies.py | 36 +++++++++++++------ 2 files changed, 50 insertions(+), 16 deletions(-) diff --git a/splink/internals/linker_components/inference.py b/splink/internals/linker_components/inference.py index c4b69f85b7..4691b6c906 100644 --- a/splink/internals/linker_components/inference.py +++ b/splink/internals/linker_components/inference.py @@ -728,7 +728,9 @@ def compare_two_records( nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf") pipeline.append_input_dataframe(nodes_with_tf) - for tf_col in linker._settings_obj._term_frequency_columns: + tf_cols = linker._settings_obj._term_frequency_columns + + for tf_col in tf_cols: tf_table_name = colname_to_tf_tablename(tf_col) if tf_table_name in cache: tf_table = cache.get_with_logging(tf_table_name) @@ -743,23 +745,41 @@ def compare_two_records( ) sql_join_tf = _join_new_table_to_df_concat_with_tf_sql( - linker, "__splink__compare_two_records_left" + linker, "__splink__compare_two_records_left", df_records_left ) pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_left_with_tf") sql_join_tf = _join_new_table_to_df_concat_with_tf_sql( - linker, "__splink__compare_two_records_right" + linker, "__splink__compare_two_records_right", df_records_right ) pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf") + pipeline = add_unique_id_and_source_dataset_cols_if_needed( + linker, + df_records_left, + pipeline, + in_tablename="__splink__compare_two_records_left_with_tf", + out_tablename="__splink__compare_two_records_left_with_tf_uid_fix", + uid_str="_left", + ) + pipeline = add_unique_id_and_source_dataset_cols_if_needed( + linker, + df_records_right, + pipeline, + in_tablename="__splink__compare_two_records_right_with_tf", + out_tablename="__splink__compare_two_records_right_with_tf_uid_fix", + uid_str="_right", + ) + cols_to_select = self._linker._settings_obj._columns_to_select_for_blocking + select_expr = ", ".join(cols_to_select) sql = f""" select {select_expr}, 0 as match_key - from __splink__compare_two_records_left_with_tf as l - cross join __splink__compare_two_records_right_with_tf as r + from __splink__compare_two_records_left_with_tf_uid_fix as l + cross join __splink__compare_two_records_right_with_tf_uid_fix as r """ pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked") diff --git a/splink/internals/term_frequencies.py b/splink/internals/term_frequencies.py index 250873e1d4..091d3fcaef 100644 --- a/splink/internals/term_frequencies.py +++ b/splink/internals/term_frequencies.py @@ -4,7 +4,7 @@ # https://github.com/moj-analytical-services/splink/pull/107 import logging import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from numpy import arange, ceil, floor, log2 from pandas import concat, cut @@ -16,6 +16,7 @@ ) from splink.internals.input_column import InputColumn from splink.internals.pipeline import CTEPipeline +from splink.internals.splink_dataframe import SplinkDataFrame # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: @@ -79,33 +80,46 @@ def _join_tf_to_df_concat_sql(linker: Linker) -> str: return sql -def _join_new_table_to_df_concat_with_tf_sql(linker: Linker, new_tablename: str) -> str: +def _join_new_table_to_df_concat_with_tf_sql( + linker: Linker, + input_tablename: str, + input_table: Optional[SplinkDataFrame] = None, +) -> str: """ - Joins any required tf columns onto new_tablename + Joins any required tf columns onto input_tablename This is needed e.g. when using linker.compare_two_records or linker.inference.find_matches_to_new_records in which the user provides new records which need tf adjustments computed """ + tf_cols_already_populated = [ + c.unquote().name + for c in input_table.columns + if c.unquote().name.startswith("tf_") + ] + tf_cols_not_already_populated = [ + c + for c in linker._settings_obj._term_frequency_columns + if c.unquote().tf_name not in tf_cols_already_populated + ] + cache = linker._intermediate_table_cache - settings_obj = linker._settings_obj - tf_cols = settings_obj._term_frequency_columns - select_cols = [f"{new_tablename}.*"] + select_cols = [f"{input_tablename}.*"] - for col in tf_cols: + for col in tf_cols_not_already_populated: tbl = colname_to_tf_tablename(col) if tbl in cache: select_cols.append(f"{tbl}.{col.tf_name}") - template = "left join {tbl} on " + new_tablename + ".{col} = {tbl}.{col}" + template = "left join {tbl} on " + input_tablename + ".{col} = {tbl}.{col}" template_with_alias = ( - "left join ({subquery}) as {_as} on " + new_tablename + ".{col} = {_as}.{col}" + "left join ({subquery}) as {_as} on " + input_tablename + ".{col} = {_as}.{col}" ) left_joins = [] - for i, col in enumerate(tf_cols): + for i, col in enumerate(tf_cols_not_already_populated): tbl = colname_to_tf_tablename(col) if tbl in cache: sql = template.format(tbl=tbl, col=col.name) @@ -127,7 +141,7 @@ def _join_new_table_to_df_concat_with_tf_sql(linker: Linker, new_tablename: str) sql = f""" select {select_cols_str} - from {new_tablename} + from {input_tablename} {left_joins_str} """