diff --git a/splink/internals/accuracy.py b/splink/internals/accuracy.py index b7d342b5fa..7ec766b5cf 100644 --- a/splink/internals/accuracy.py +++ b/splink/internals/accuracy.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from splink.internals.linker import Linker + from splink.internals.settings import Settings def truth_space_table_from_labels_with_predictions_sqls( @@ -289,8 +290,8 @@ def truth_space_table_from_labels_with_predictions_sqls( return sqls -def _select_found_by_blocking_rules(linker: "Linker") -> str: - brs = linker._settings_obj._blocking_rules_to_generate_predictions +def _select_found_by_blocking_rules(settings_obj: "Settings") -> str: + brs = settings_obj._blocking_rules_to_generate_predictions if brs: br_strings = [ @@ -425,7 +426,7 @@ def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename): ) sqls.extend(sqls_2) - br_col = _select_found_by_blocking_rules(linker) + br_col = _select_found_by_blocking_rules(linker._settings_obj) sql = f""" select *, {br_col} diff --git a/splink/internals/linker_components/inference.py b/splink/internals/linker_components/inference.py index 4691b6c906..1556ede0d9 100644 --- a/splink/internals/linker_components/inference.py +++ b/splink/internals/linker_components/inference.py @@ -4,6 +4,7 @@ import time from typing import TYPE_CHECKING, Any +from splink.internals.accuracy import _select_found_by_blocking_rules from splink.internals.blocking import ( BlockingRule, block_using_rules_sqls, @@ -639,7 +640,10 @@ def find_matches_to_new_records( return predictions def compare_two_records( - self, record_1: dict[str, Any], record_2: dict[str, Any] + self, + record_1: dict[str, Any], + record_2: dict[str, Any], + include_found_by_blocking_rules: bool = False, ) -> SplinkDataFrame: """Use the linkage model to compare and score a pairwise record comparison based on the two input records provided @@ -799,6 +803,15 @@ def compare_two_records( ) pipeline.enqueue_list_of_sqls(sqls) + if include_found_by_blocking_rules: + br_col = _select_found_by_blocking_rules(linker._settings_obj) + sql = f""" + select *, {br_col} + from __splink__df_predict + """ + + pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules") + predictions = linker._db_api.sql_pipeline_to_splink_dataframe( pipeline, use_cache=False ) diff --git a/splink/realtime.py b/splink/realtime.py index 10faab0760..025e25c948 100644 --- a/splink/realtime.py +++ b/splink/realtime.py @@ -3,14 +3,15 @@ from pathlib import Path from typing import Any, Dict -from .internals.database_api import DatabaseAPISubClass -from .internals.misc import ascii_uid -from .internals.pipeline import CTEPipeline -from .internals.predict import ( +from splink.internals.accuracy import _select_found_by_blocking_rules +from splink.internals.database_api import DatabaseAPISubClass +from splink.internals.misc import ascii_uid +from splink.internals.pipeline import CTEPipeline +from splink.internals.predict import ( predict_from_comparison_vectors_sqls_using_settings, ) -from .internals.settings_creator import SettingsCreator -from .internals.splink_dataframe import SplinkDataFrame +from splink.internals.settings_creator import SettingsCreator +from splink.internals.splink_dataframe import SplinkDataFrame __all__ = [ "compare_records", @@ -26,6 +27,7 @@ def compare_records( settings: SettingsCreator | dict[str, Any] | Path | str, db_api: DatabaseAPISubClass, use_sql_from_cache: bool = True, + include_found_by_blocking_rules: bool = False, ) -> SplinkDataFrame: """Compare two records and compute similarity scores without requiring a Linker. Assumes any required term frequency values are provided in the input records. @@ -83,6 +85,13 @@ def compare_records( settings_obj = settings_creator.get_settings(db_api.sql_dialect.sql_dialect_str) + retain_matching_columns = settings_obj._retain_matching_columns + retain_intermediate_calculation_columns = ( + settings_obj._retain_intermediate_calculation_columns + ) + settings_obj._retain_matching_columns = True + settings_obj._retain_intermediate_calculation_columns = True + pipeline = CTEPipeline([df_records_left, df_records_right]) cols_to_select = settings_obj._columns_to_select_for_blocking @@ -109,9 +118,23 @@ def compare_records( ) pipeline.enqueue_list_of_sqls(sqls) + if include_found_by_blocking_rules: + br_col = _select_found_by_blocking_rules(settings_obj) + sql = f""" + select *, {br_col} + from __splink__df_predict + """ + + pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules") + predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline) _sql_used_for_compare_records_cache["sql"] = predictions.sql_used_to_create _sql_used_for_compare_records_cache["uid"] = uid + settings_obj._retain_matching_columns = retain_matching_columns + settings_obj._retain_intermediate_calculation_columns = ( + retain_intermediate_calculation_columns + ) + return predictions