Skip to content

Commit

Permalink
Merge pull request #1860 from moj-analytical-services/custom-distance…
Browse files Browse the repository at this point in the history
…-function

Custom distance function level + comparison at thresholds
  • Loading branch information
ADBond authored Jan 22, 2024
2 parents 156efff + 2f3aeee commit 9f63045
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
50 changes: 50 additions & 0 deletions splink/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,56 @@ def create_label_for_charts(self) -> str:
return f"Jaccard distance of '{col.label} >= {self.distance_threshold}'"


class DistanceFunctionLevel(ComparisonLevelCreator):
def __init__(
self,
col_name: Union[str, ColumnExpression],
distance_function_name: str,
distance_threshold: Union[int, float],
higher_is_more_similar: bool = True,
):
"""A comparison level using an arbitrary distance function
e.g. `custom_distance(val_l, val_r) >= (<=) distance_threshold`
The function given by `distance_function_name` must exist in the SQL
backend you use, and must take two parameters of the type in `col_name,
returning a numeric type
Args:
col_name (str | ColumnExpression): Input column name
distance_function_name (str): the name of the SQL distance function
distance_threshold (Union[int, float]): The threshold to use to assess
similarity
higher_is_more_similar (bool): Are higher values of the distance function
more similar? (e.g. True for Jaro-Winkler, False for Levenshtein)
Default is True
"""

self.col_expression = ColumnExpression.instantiate_if_str(col_name)
self.distance_function_name = distance_function_name
self.distance_threshold = distance_threshold
self.higher_is_more_similar = higher_is_more_similar

def create_sql(self, sql_dialect: SplinkDialect) -> str:
self.col_expression.sql_dialect = sql_dialect
col = self.col_expression
d_fn = self.distance_function_name
less_or_greater_than = ">" if self.higher_is_more_similar else "<"
return (
f"{d_fn}({col.name_l}, {col.name_r}) "
f"{less_or_greater_than}= {self.distance_threshold}"
)

def create_label_for_charts(self) -> str:
col = self.col_expression
less_or_greater = "greater" if self.higher_is_more_similar else "less"
return (
f"`{self.distance_function_name}` distance of '{col.label} "
f"{less_or_greater} than {self.distance_threshold}'"
)


class DatediffLevel(ComparisonLevelCreator):
def __init__(
self,
Expand Down
68 changes: 68 additions & 0 deletions splink/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,74 @@ def create_output_column_name(self) -> str:
return self.col_expression.output_column_name


class DistanceFunctionAtThresholds(ComparisonCreator):
def __init__(
self,
col_name: str,
distance_function_name,
distance_threshold_or_thresholds: Union[Iterable[float], float],
higher_is_more_similar: bool = True,
):
"""
Represents a comparison of the data in `col_name` with three or more levels:
- Exact match in `col_name`
- Custom distance function levels at specified thresholds
- ...
- Anything else
For example, with distance_threshold_or_thresholds = [1, 3]
and distance_function 'hamming', with higher_is_more_similar False
the levels are:
- Exact match in `col_name`
- Hamming distance of `col_name` <= 1
- Hamming distance of `col_name` <= 3
- Anything else
Args:
col_name (str): The name of the column to compare.
distance_function_name (str): the name of the SQL distance function
distance_threshold_or_thresholds (Union[float, list], optional): The
threshold(s) to use for the distance function level(s).
higher_is_more_similar (bool): Are higher values of the distance function
more similar? (e.g. True for Jaro-Winkler, False for Levenshtein)
Default is True
"""

thresholds_as_iterable = ensure_is_iterable(distance_threshold_or_thresholds)
self.thresholds = [*thresholds_as_iterable]
self.distance_function_name = distance_function_name
self.higher_is_more_similar = higher_is_more_similar
super().__init__(col_name)

def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
return [
cll.NullLevel(self.col_expression),
cll.ExactMatchLevel(self.col_expression),
*[
cll.DistanceFunctionLevel(
self.col_expression,
self.distance_function_name,
threshold,
higher_is_more_similar=self.higher_is_more_similar,
)
for threshold in self.thresholds
],
cll.ElseLevel(),
]

def create_description(self) -> str:
comma_separated_thresholds_string = ", ".join(map(str, self.thresholds))
return (
f"Exact match '{self.col_expression.label}' vs. "
f"`{self.distance_function_name}` at thresholds "
f"{comma_separated_thresholds_string} vs. "
"anything else"
)

def create_output_column_name(self) -> str:
return self.col_expression.output_column_name


class DatediffAtThresholds(ComparisonCreator):
def __init__(
self,
Expand Down

0 comments on commit 9f63045

Please sign in to comment.