Skip to content

Commit

Permalink
Allow the `found_by_blocking_rules` column to be included when comparing records
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Nov 11, 2024
1 parent 1650a2b commit a495811
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
7 changes: 4 additions & 3 deletions splink/internals/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

if TYPE_CHECKING:
from splink.internals.linker import Linker
from splink.internals.settings import Settings


def truth_space_table_from_labels_with_predictions_sqls(
Expand Down Expand Up @@ -289,8 +290,8 @@ def truth_space_table_from_labels_with_predictions_sqls(
return sqls


def _select_found_by_blocking_rules(linker: "Linker") -> str:
brs = linker._settings_obj._blocking_rules_to_generate_predictions
def _select_found_by_blocking_rules(settings_obj: "Settings") -> str:
brs = settings_obj._blocking_rules_to_generate_predictions

if brs:
br_strings = [
Expand Down Expand Up @@ -425,7 +426,7 @@ def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename):
)

sqls.extend(sqls_2)
br_col = _select_found_by_blocking_rules(linker)
br_col = _select_found_by_blocking_rules(linker._settings_obj)

sql = f"""
select *, {br_col}
Expand Down
15 changes: 14 additions & 1 deletion splink/internals/linker_components/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from typing import TYPE_CHECKING, Any

from splink.internals.accuracy import _select_found_by_blocking_rules
from splink.internals.blocking import (
BlockingRule,
block_using_rules_sqls,
Expand Down Expand Up @@ -639,7 +640,10 @@ def find_matches_to_new_records(
return predictions

def compare_two_records(
self, record_1: dict[str, Any], record_2: dict[str, Any]
self,
record_1: dict[str, Any],
record_2: dict[str, Any],
include_found_by_blocking_rules: bool = False,
) -> SplinkDataFrame:
"""Use the linkage model to compare and score a pairwise record comparison
based on the two input records provided
Expand Down Expand Up @@ -799,6 +803,15 @@ def compare_two_records(
)
pipeline.enqueue_list_of_sqls(sqls)

if include_found_by_blocking_rules:
br_col = _select_found_by_blocking_rules(linker._settings_obj)
sql = f"""
select *, {br_col}
from __splink__df_predict
"""

pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules")

predictions = linker._db_api.sql_pipeline_to_splink_dataframe(
pipeline, use_cache=False
)
Expand Down
35 changes: 29 additions & 6 deletions splink/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from pathlib import Path
from typing import Any, Dict

from .internals.database_api import DatabaseAPISubClass
from .internals.misc import ascii_uid
from .internals.pipeline import CTEPipeline
from .internals.predict import (
from splink.internals.accuracy import _select_found_by_blocking_rules
from splink.internals.database_api import DatabaseAPISubClass
from splink.internals.misc import ascii_uid
from splink.internals.pipeline import CTEPipeline
from splink.internals.predict import (
predict_from_comparison_vectors_sqls_using_settings,
)
from .internals.settings_creator import SettingsCreator
from .internals.splink_dataframe import SplinkDataFrame
from splink.internals.settings_creator import SettingsCreator
from splink.internals.splink_dataframe import SplinkDataFrame

__all__ = [
"compare_records",
Expand All @@ -26,6 +27,7 @@ def compare_records(
settings: SettingsCreator | dict[str, Any] | Path | str,
db_api: DatabaseAPISubClass,
use_sql_from_cache: bool = True,
include_found_by_blocking_rules: bool = False,
) -> SplinkDataFrame:
"""Compare two records and compute similarity scores without requiring a Linker.
Assumes any required term frequency values are provided in the input records.
Expand Down Expand Up @@ -83,6 +85,13 @@ def compare_records(

settings_obj = settings_creator.get_settings(db_api.sql_dialect.sql_dialect_str)

retain_matching_columns = settings_obj._retain_matching_columns
retain_intermediate_calculation_columns = (
settings_obj._retain_intermediate_calculation_columns
)
settings_obj._retain_matching_columns = True
settings_obj._retain_intermediate_calculation_columns = True

pipeline = CTEPipeline([df_records_left, df_records_right])

cols_to_select = settings_obj._columns_to_select_for_blocking
Expand All @@ -109,9 +118,23 @@ def compare_records(
)
pipeline.enqueue_list_of_sqls(sqls)

if include_found_by_blocking_rules:
br_col = _select_found_by_blocking_rules(settings_obj)
sql = f"""
select *, {br_col}
from __splink__df_predict
"""

pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules")

predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline)

_sql_used_for_compare_records_cache["sql"] = predictions.sql_used_to_create
_sql_used_for_compare_records_cache["uid"] = uid

settings_obj._retain_matching_columns = retain_matching_columns
settings_obj._retain_intermediate_calculation_columns = (
retain_intermediate_calculation_columns
)

return predictions

0 comments on commit a495811

Please sign in to comment.