Skip to content

Commit

Permalink
Merge pull request #1714 from moj-analytical-services/migrate-tests
Browse files Browse the repository at this point in the history
Migrate tests for Splink 4 (`ComparisonLevelCreator` and `ComparisonCreator` and related changes)
  • Loading branch information
ADBond authored Jan 22, 2024
2 parents fe45feb + f0c63d9 commit 42b593d
Show file tree
Hide file tree
Showing 42 changed files with 1,030 additions and 1,097 deletions.
4 changes: 4 additions & 0 deletions splink/column_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,7 @@ def label(self) -> str:
return "transformed " + self.raw_sql_expression
else:
return self.raw_sql_expression

def __repr__(self):
    """Return a constructor-style debug string for this column expression.

    Only the raw SQL expression is shown; transforms are not reflected yet.
    """
    # TODO: need to include transform info, but guard for case of no dialect
    raw_sql = self.raw_sql_expression
    return f"ColumnExpression(sql_expression='{raw_sql}')"
8 changes: 4 additions & 4 deletions splink/comparison_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def configure(
self,
*,
term_frequency_adjustments: bool = False,
m_probabilities: list[float] = None,
u_probabilities: list[float] = None,
m_probabilities: List[float] = None,
u_probabilities: List[float] = None,
) -> "ComparisonCreator":
"""
Configure the comparison creator with m and u probabilities. The first
Expand Down Expand Up @@ -185,7 +185,7 @@ def m_probabilities(self):

@final
@m_probabilities.setter
def m_probabilities(self, m_probabilities: list[float]):
def m_probabilities(self, m_probabilities: List[float]):
if m_probabilities:
num_probs_supplied = len(m_probabilities)
num_non_null_levels = self.num_non_null_levels
Expand All @@ -204,7 +204,7 @@ def u_probabilities(self):

@final
@u_probabilities.setter
def u_probabilities(self, u_probabilities: list[float]):
def u_probabilities(self, u_probabilities: List[float]):
if u_probabilities:
num_probs_supplied = len(u_probabilities)
num_non_null_levels = self.num_non_null_levels
Expand Down
6 changes: 3 additions & 3 deletions splink/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class CustomComparison(ComparisonCreator):
def __init__(
self,
output_column_name: str,
comparison_levels: list[Union[ComparisonLevelCreator, dict]],
comparison_levels: List[Union[ComparisonLevelCreator, dict]],
description: str = None,
):
"""
Expand Down Expand Up @@ -437,8 +437,8 @@ def __init__(
self,
col_name: str,
*,
date_metrics: Union[str, list[str]],
date_thresholds: Union[int, list[int]],
date_metrics: Union[str, List[str]],
date_thresholds: Union[int, List[int]],
cast_strings_to_dates: bool = False,
date_format: str = None,
term_frequency_adjustments=False,
Expand Down
14 changes: 14 additions & 0 deletions splink/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ def _try_parse_date_raw(self, name: str, date_format: str = None):
date_format = self.default_date_format
return f"""try_strptime({name}, '{date_format}')"""

# TODO: this fallback is only needed for duckdb < 0.9.0.
# Should we drop support for those versions? (This would only affect cll,
# the comparison level library - the engine should still work.)
def array_intersect(self, clc: "ComparisonLevelCreator"):
    """Return a SQL condition testing the overlap of the left/right arrays.

    Side effect: ``clc.col_expression.sql_dialect`` is set to this dialect
    before the column names are read.

    Args:
        clc: Level creator supplying ``col_expression`` (with ``name_l`` and
            ``name_r``) and ``min_intersection``.

    Returns:
        A SQL boolean expression string comparing the computed overlap size
        against ``clc.min_intersection``.
    """
    # Bind the dialect before reading anything off the column expression.
    clc.col_expression.sql_dialect = self
    column = clc.col_expression
    minimum_overlap = clc.min_intersection
    left, right = column.name_l, column.name_r

    # sum of individual (unique) array sizes, minus the (unique) union
    distinct_left = f"list_unique({left})"
    distinct_right = f"list_unique({right})"
    distinct_union = f"list_unique(list_concat({left}, {right}))"
    condition = (
        f"{distinct_left} + {distinct_right} - {distinct_union}"
        f" >= {minimum_overlap}"
    )
    return condition.strip()

def _regex_extract_raw(self, name: str, pattern: str, capture_group: int = 0):
    """Build a ``regexp_extract`` SQL call for the given column and pattern.

    Args:
        name: SQL text of the column/expression to search, inserted as-is.
        pattern: Regular expression, wrapped in single quotes unescaped.
        capture_group: Capture group index passed as the third argument.

    Returns:
        The SQL expression string.
    """
    quoted_pattern = f"'{pattern}'"
    return f"regexp_extract({name}, {quoted_pattern}, {capture_group})"

Expand Down
4 changes: 4 additions & 0 deletions splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@ def _instantiate_comparison_levels(self, settings_dict):
instances are instead replaced with ComparisonLevels
"""
dialect = self._sql_dialect
if settings_dict is None:
return
if "comparisons" not in settings_dict:
return
comparisons = settings_dict["comparisons"]
Expand Down Expand Up @@ -1184,6 +1186,8 @@ def load_settings(
settings_dict["sql_dialect"] = settings_dict.get(
"sql_dialect", self._sql_dialect
)

self._instantiate_comparison_levels(settings_dict)
self._settings_dict = settings_dict
self._settings_obj_ = Settings(settings_dict)
self._validate_input_dfs()
Expand Down
24 changes: 12 additions & 12 deletions tests/basic_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,23 @@ def name_comparison(cll, sn: str) -> dict:
"output_column_name": "first_name_and_surname",
"comparison_levels": [
# Null level
cll.or_(cll.null_level("first_name"), cll.null_level(sn)),
cll.Or(cll.NullLevel("first_name"), cll.NullLevel(sn)),
# Exact match on fn and sn
cll.or_(
cll.exact_match_level("first_name"),
cll.exact_match_level(sn),
cll.Or(
cll.ExactMatchLevel("first_name"),
cll.ExactMatchLevel(sn),
).configure(
m_probability=0.8,
label_for_charts="Exact match on first name or surname",
),
# (Levenshtein(fn) and jaro_winkler(fn)) or levenshtein(sur)
cll.and_(
cll.or_(
cll.levenshtein_level("first_name", 2),
cll.jaro_winkler_level("first_name", 0.8),
m_probability=0.8,
),
cll.levenshtein_level(sn, 3),
cll.And(
cll.Or(
cll.LevenshteinLevel("first_name", 2),
cll.JaroWinklerLevel("first_name", 0.8),
).configure(m_probability=0.8),
cll.LevenshteinLevel(sn, 3),
),
cll.else_level(0.1),
cll.ElseLevel().configure(m_probability=0.1),
],
}
90 changes: 20 additions & 70 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,10 @@
TEXT,
)

# import splink.duckdb.blocking_rule_library as brl_duckdb
# import splink.duckdb.comparison_library as cl_duckdb
# import splink.duckdb.comparison_template_library as ctl_duckdb
# import splink.postgres.blocking_rule_library as brl_postgres
# import splink.postgres.comparison_library as cl_postgres
# import splink.postgres.comparison_template_library as ctl_postgres
# import splink.spark.blocking_rule_library as brl_spark
# import splink.spark.comparison_library as cl_spark
# import splink.spark.comparison_template_library as ctl_spark
# import splink.sqlite.blocking_rule_library as brl_sqlite
# import splink.sqlite.comparison_library as cl_sqlite
# import splink.sqlite.comparison_template_library as ctl_sqlite
import splink.duckdb.blocking_rule_library as brl_duckdb
import splink.postgres.blocking_rule_library as brl_postgres
import splink.spark.blocking_rule_library as brl_spark
import splink.sqlite.blocking_rule_library as brl_sqlite
from splink.duckdb.linker import DuckDBLinker
from splink.postgres.linker import PostgresLinker
from splink.spark.linker import SparkLinker
Expand Down Expand Up @@ -50,20 +42,10 @@ def load_frame_from_csv(self, path):
def load_frame_from_parquet(self, path):
# Read the parquet file at `path` with pd.read_parquet and return the result as-is.
return pd.read_parquet(path)

# @property
# @abstractmethod
# def cl(self):
# pass

# @property
# @abstractmethod
# def ctl(self):
# pass

# @property
# @abstractmethod
# def brl(self):
# pass
@property
@abstractmethod
def brl(self):
# Abstract property with no default; the concrete helpers in this file
# override it to return a backend-specific blocking_rule_library module.
pass


class DuckDBTestHelper(TestHelper):
Expand All @@ -78,17 +60,9 @@ def convert_frame(self, df):
def date_format(self):
return "%Y-%m-%d"

# @property
# def cl(self):
# return cl_duckdb

# @property
# def ctl(self):
# return ctl_duckdb

# @property
# def brl(self):
# return brl_duckdb
@property
def brl(self):
# Blocking rule library for this helper: the module imported as brl_duckdb
# (splink.duckdb.blocking_rule_library).
return brl_duckdb


class SparkTestHelper(TestHelper):
Expand Down Expand Up @@ -117,17 +91,9 @@ def load_frame_from_parquet(self, path):
df.persist()
return df

# @property
# def cl(self):
# return cl_spark

# @property
# def ctl(self):
# return ctl_spark

# @property
# def brl(self):
# return brl_spark
@property
def brl(self):
# Blocking rule library for this helper: the module imported as brl_spark
# (splink.spark.blocking_rule_library).
return brl_spark


class SQLiteTestHelper(TestHelper):
Expand Down Expand Up @@ -161,17 +127,9 @@ def load_frame_from_csv(self, path):
def load_frame_from_parquet(self, path):
# Load via the parent class's load_frame_from_parquet, then pass the result
# through this helper's convert_frame.
return self.convert_frame(super().load_frame_from_parquet(path))

# @property
# def cl(self):
# return cl_sqlite

# @property
# def ctl(self):
# return ctl_sqlite

# @property
# def brl(self):
# return brl_sqlite
@property
def brl(self):
# Blocking rule library for this helper: the module imported as brl_sqlite
# (splink.sqlite.blocking_rule_library).
return brl_sqlite


class PostgresTestHelper(TestHelper):
Expand Down Expand Up @@ -218,17 +176,9 @@ def load_frame_from_csv(self, path):
def load_frame_from_parquet(self, path):
# Load via the parent class's load_frame_from_parquet, then pass the result
# through this helper's convert_frame.
return self.convert_frame(super().load_frame_from_parquet(path))

# @property
# def cl(self):
# return cl_postgres

# @property
# def ctl(self):
# return ctl_postgres

# @property
# def brl(self):
# return brl_postgres
@property
def brl(self):
# Blocking rule library for this helper: the module imported as brl_postgres
# (splink.postgres.blocking_rule_library).
return brl_postgres


class SplinkTestException(Exception):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
predictions_from_sample_of_pairwise_labels_sql,
truth_space_table_from_labels_with_predictions_sqls,
)
from splink.duckdb.blocking_rule_library import exact_match_rule
from splink.duckdb.comparison_library import exact_match
from splink.comparison_library import ExactMatch
from splink.duckdb.blocking_rule_library import block_on
from splink.duckdb.linker import DuckDBLinker

from .basic_settings import get_settings_dict
Expand All @@ -33,13 +33,13 @@ def test_scored_labels_table():
settings = {
"link_type": "dedupe_only",
"comparisons": [
exact_match("first_name"),
exact_match("surname"),
exact_match("dob"),
ExactMatch("first_name"),
ExactMatch("surname"),
ExactMatch("dob"),
],
"blocking_rules_to_generate_predictions": [
"l.surname = r.surname",
exact_match_rule("dob"),
block_on("dob"),
],
}

Expand Down Expand Up @@ -92,13 +92,13 @@ def test_truth_space_table():
settings = {
"link_type": "dedupe_only",
"comparisons": [
exact_match("first_name"),
exact_match("surname"),
exact_match("dob"),
ExactMatch("first_name"),
ExactMatch("surname"),
ExactMatch("dob"),
],
"blocking_rules_to_generate_predictions": [
"l.surname = r.surname",
exact_match_rule("dob"),
block_on("dob"),
],
}

Expand Down
11 changes: 5 additions & 6 deletions tests/test_analyse_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_blocking_records_accuracy(test_helpers, dialect):
)

blocking_rules = [
brl.exact_match_rule("first_name"),
brl.block_on("first_name"),
brl.block_on(["first_name", "surname"]),
"l.dob = r.dob",
]
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_blocking_records_accuracy(test_helpers, dialect):
blocking_rules = [
"l.surname = r.surname", # 2l:2r,
brl.or_(
brl.exact_match_rule("first_name"),
brl.block_on("first_name"),
"substr(l.dob,1,4) = substr(r.dob,1,4)",
), # 1r:1r, 1l:2l, 1l:2r
"l.surname = r.surname",
Expand Down Expand Up @@ -438,18 +438,17 @@ def test_cumulative_br_funs(test_helpers, dialect):
linker.cumulative_comparisons_from_blocking_rules_records(
[
"l.first_name = r.first_name",
brl.exact_match_rule("surname"),
brl.block_on("surname"),
]
)

linker.cumulative_num_comparisons_from_blocking_rules_chart(
[
"l.first_name = r.first_name",
brl.exact_match_rule("surname"),
brl.block_on("surname"),
]
)

assert (
linker.count_num_comparisons_from_blocking_rule(brl.exact_match_rule("surname"))
== 3167
linker.count_num_comparisons_from_blocking_rule(brl.block_on("surname")) == 3167
)
13 changes: 7 additions & 6 deletions tests/test_array_based_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd

import splink.comparison_library as cl
from tests.decorator import mark_with_dialects_including


Expand Down Expand Up @@ -31,7 +32,7 @@ def test_simple_example_link_only(test_helpers, dialect):
},
"l.gender = r.gender",
],
"comparisons": [helper.cl.array_intersect_at_sizes("postcode", [1])],
"comparisons": [cl.ArrayIntersectAtSizes("postcode", [1])],
}
## the pairs returned by the first blocking rule are (1,6),(2,4),(2,6)
## the additional pairs returned by the second blocking rule are (1,4),(3,5)
Expand Down Expand Up @@ -105,7 +106,7 @@ def test_array_based_blocking_with_random_data_dedupe(test_helpers, dialect):
"blocking_rules_to_generate_predictions": blocking_rules,
"unique_id_column_name": "unique_id",
"additional_columns_to_retain": ["cluster"],
"comparisons": [helper.cl.array_intersect_at_sizes("array_column_1", [1])],
"comparisons": [cl.ArrayIntersectAtSizes("array_column_1", [1])],
}
linker = helper.Linker(input_data, settings, **helper.extra_linker_args())
linker.debug_mode = False
Expand Down Expand Up @@ -152,7 +153,7 @@ def test_array_based_blocking_with_random_data_link_only(test_helpers, dialect):
"blocking_rules_to_generate_predictions": blocking_rules,
"unique_id_column_name": "cluster",
"additional_columns_to_retain": ["cluster"],
"comparisons": [helper.cl.array_intersect_at_sizes("array_column_1", [1])],
"comparisons": [cl.ArrayIntersectAtSizes("array_column_1", [1])],
}
linker = helper.Linker(
[input_data_l, input_data_r], settings, **helper.extra_linker_args()
Expand Down Expand Up @@ -218,9 +219,9 @@ def test_link_only_unique_id_ambiguity(test_helpers, dialect):
"l.surname = r.surname",
],
"comparisons": [
helper.cl.exact_match("first_name"),
helper.cl.exact_match("surname"),
helper.cl.exact_match("postcode"),
cl.ExactMatch("first_name"),
cl.ExactMatch("surname"),
cl.ExactMatch("postcode"),
],
"retain_intermediate_calculation_columns": True,
}
Expand Down
Loading

0 comments on commit 42b593d

Please sign in to comment.