Commit: works with tf columns
RobinL committed Nov 8, 2024
1 parent ee31efd commit 511a644
Showing 2 changed files with 50 additions and 16 deletions.
30 changes: 25 additions & 5 deletions splink/internals/linker_components/inference.py
@@ -728,7 +728,9 @@ def compare_two_records(
         nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
         pipeline.append_input_dataframe(nodes_with_tf)

-        for tf_col in linker._settings_obj._term_frequency_columns:
+        tf_cols = linker._settings_obj._term_frequency_columns
+
+        for tf_col in tf_cols:
             tf_table_name = colname_to_tf_tablename(tf_col)
             if tf_table_name in cache:
                 tf_table = cache.get_with_logging(tf_table_name)
@@ -743,23 +745,41 @@ def compare_two_records(
             )

         sql_join_tf = _join_new_table_to_df_concat_with_tf_sql(
-            linker, "__splink__compare_two_records_left"
+            linker, "__splink__compare_two_records_left", df_records_left
         )

         pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_left_with_tf")

         sql_join_tf = _join_new_table_to_df_concat_with_tf_sql(
-            linker, "__splink__compare_two_records_right"
+            linker, "__splink__compare_two_records_right", df_records_right
         )

         pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf")

+        pipeline = add_unique_id_and_source_dataset_cols_if_needed(
+            linker,
+            df_records_left,
+            pipeline,
+            in_tablename="__splink__compare_two_records_left_with_tf",
+            out_tablename="__splink__compare_two_records_left_with_tf_uid_fix",
+            uid_str="_left",
+        )
+        pipeline = add_unique_id_and_source_dataset_cols_if_needed(
+            linker,
+            df_records_right,
+            pipeline,
+            in_tablename="__splink__compare_two_records_right_with_tf",
+            out_tablename="__splink__compare_two_records_right_with_tf_uid_fix",
+            uid_str="_right",
+        )
+
         cols_to_select = self._linker._settings_obj._columns_to_select_for_blocking

         select_expr = ", ".join(cols_to_select)
         sql = f"""
         select {select_expr}, 0 as match_key
-        from __splink__compare_two_records_left_with_tf as l
-        cross join __splink__compare_two_records_right_with_tf as r
+        from __splink__compare_two_records_left_with_tf_uid_fix as l
+        cross join __splink__compare_two_records_right_with_tf_uid_fix as r
         """
         pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked")

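Taken together, the inference.py changes thread the caller's input records (df_records_left / df_records_right) through to the tf-join SQL and normalise unique_id / source_dataset columns before blocking. A minimal sketch of the code path this enables, assuming a DuckDB-backed Splink 4 model; the dataset, column names and tf values are illustrative, not taken from the commit:

import splink.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, splink_datasets

# term_frequency_adjustments=True puts "first_name" into
# linker._settings_obj._term_frequency_columns, the list the changed loop iterates.
settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[
        cl.ExactMatch("first_name").configure(term_frequency_adjustments=True),
        cl.ExactMatch("surname"),
    ],
    blocking_rules_to_generate_predictions=["l.surname = r.surname"],
)
linker = Linker(splink_datasets.fake_1000, settings, db_api=DuckDBAPI())

# Records that already carry a tf_ column: after this commit the tf-join SQL
# is told about the input records, so the supplied tf_first_name value is
# kept rather than a left join being generated for it.
rec_l = {"unique_id": 1, "first_name": "Robin", "surname": "Linacre", "tf_first_name": 0.002}
rec_r = {"unique_id": 2, "first_name": "Robyn", "surname": "Linacre", "tf_first_name": 0.001}

res = linker.inference.compare_two_records(rec_l, rec_r)
print(res.as_record_dict())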
36 changes: 25 additions & 11 deletions splink/internals/term_frequencies.py
@@ -4,7 +4,7 @@
 # https://github.com/moj-analytical-services/splink/pull/107
 import logging
 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional

 from numpy import arange, ceil, floor, log2
 from pandas import concat, cut
@@ -16,6 +16,7 @@
 )
 from splink.internals.input_column import InputColumn
 from splink.internals.pipeline import CTEPipeline
+from splink.internals.splink_dataframe import SplinkDataFrame

 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
 if TYPE_CHECKING:
@@ -79,33 +80,46 @@ def _join_tf_to_df_concat_sql(linker: Linker) -> str:
     return sql


-def _join_new_table_to_df_concat_with_tf_sql(linker: Linker, new_tablename: str) -> str:
+def _join_new_table_to_df_concat_with_tf_sql(
+    linker: Linker,
+    input_tablename: str,
+    input_table: Optional[SplinkDataFrame] = None,
+) -> str:
     """
-    Joins any required tf columns onto new_tablename
+    Joins any required tf columns onto input_tablename
     This is needed e.g. when using linker.compare_two_records
     or linker.inference.find_matches_to_new_records in which the user provides
     new records which need tf adjustments computed
     """

+    tf_cols_already_populated = [
+        c.unquote().name
+        for c in input_table.columns
+        if c.unquote().name.startswith("tf_")
+    ]
+    tf_cols_not_already_populated = [
+        c
+        for c in linker._settings_obj._term_frequency_columns
+        if c.unquote().tf_name not in tf_cols_already_populated
+    ]
+
     cache = linker._intermediate_table_cache
-    settings_obj = linker._settings_obj
-    tf_cols = settings_obj._term_frequency_columns

-    select_cols = [f"{new_tablename}.*"]
+    select_cols = [f"{input_tablename}.*"]

-    for col in tf_cols:
+    for col in tf_cols_not_already_populated:
         tbl = colname_to_tf_tablename(col)
         if tbl in cache:
             select_cols.append(f"{tbl}.{col.tf_name}")

-    template = "left join {tbl} on " + new_tablename + ".{col} = {tbl}.{col}"
+    template = "left join {tbl} on " + input_tablename + ".{col} = {tbl}.{col}"
     template_with_alias = (
-        "left join ({subquery}) as {_as} on " + new_tablename + ".{col} = {_as}.{col}"
+        "left join ({subquery}) as {_as} on " + input_tablename + ".{col} = {_as}.{col}"
     )

     left_joins = []
-    for i, col in enumerate(tf_cols):
+    for i, col in enumerate(tf_cols_not_already_populated):
         tbl = colname_to_tf_tablename(col)
         if tbl in cache:
             sql = template.format(tbl=tbl, col=col.name)
@@ -127,7 +141,7 @@ def _join_new_table_to_df_concat_with_tf_sql(linker: Linker, new_tablename: str)

     sql = f"""
     select {select_cols_str}
-    from {new_tablename}
+    from {input_tablename}
     {left_joins_str}
     """
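The heart of the term_frequencies.py change is the split between tf columns the input records already carry and those that still need joining. A standalone sketch of that filtering logic, using plain strings in place of Splink's InputColumn objects (whose unquote().name and tf_name properties the real code relies on); the helper name and data are invented for illustration:

def partition_tf_columns(
    record_columns: list[str], tf_columns: list[str]
) -> tuple[list[str], list[str]]:
    """Split tf columns into those already present on the input records
    and those that still need a left join against a cached tf table."""
    already_populated = [c for c in record_columns if c.startswith("tf_")]
    needs_join = [c for c in tf_columns if f"tf_{c}" not in already_populated]
    return already_populated, needs_join

already, needed = partition_tf_columns(
    record_columns=["unique_id", "first_name", "surname", "tf_first_name"],
    tf_columns=["first_name", "surname"],
)
print(already)  # ['tf_first_name'] -- supplied by the caller, passed through untouched
print(needed)   # ['surname'] -- only this column gets a left join in the generated SQL

Only the columns in needed are added to the select list and the generated left joins; caller-supplied tf values flow through the {input_tablename}.* select unchanged.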
