Skip to content

Commit

Permalink
Merge pull request #1718 from moj-analytical-services/validate_distan…
Browse files Browse the repository at this point in the history
…ce_thresholds

Add `distance_threshold` check
  • Loading branch information
ThomasHepworth authored Nov 10, 2023
2 parents 73d3e9e + 73f4ff7 commit 30ad12f
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion splink/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
from .dialects import SplinkDialect


def validate_distance_threshold(
    lower_bound: Union[int, float],
    upper_bound: Union[int, float],
    distance_threshold: Union[int, float],
    level_name: str,
) -> Union[int, float]:
    """Return ``distance_threshold`` unchanged when it lies within the bounds.

    Both bounds are inclusive, so a threshold equal to ``lower_bound`` or
    ``upper_bound`` is accepted.

    Args:
        lower_bound (Union[int, float]): Smallest permitted threshold.
        upper_bound (Union[int, float]): Largest permitted threshold.
        distance_threshold (Union[int, float]): The value being validated.
        level_name (str): Name reported in the error message to identify
            which comparison level rejected the value.

    Returns:
        Union[int, float]: The ``distance_threshold`` that was passed in.

    Raises:
        ValueError: If ``distance_threshold`` is outside
            ``[lower_bound, upper_bound]``.
    """
    within_bounds = lower_bound <= distance_threshold <= upper_bound
    # Guard clause: reject out-of-range values first so the success path
    # is a plain return at the end.
    if not within_bounds:
        raise ValueError(
            "'distance_threshold' must be between "
            f"{lower_bound} and {upper_bound} for {level_name}"
        )
    return distance_threshold


class NullLevel(ComparisonLevelCreator):
def create_sql(self, sql_dialect: SplinkDialect) -> str:
col = self.input_column(sql_dialect)
Expand Down Expand Up @@ -83,8 +99,14 @@ def __init__(self, col_name: str, distance_threshold: Union[int, float]):
distance_threshold (Union[int, float]): The threshold to use to assess
similarity
"""

super().__init__(col_name)
self.distance_threshold = distance_threshold
self.distance_threshold = validate_distance_threshold(
lower_bound=0,
upper_bound=1,
distance_threshold=distance_threshold,
level_name=self.__class__.__name__,
)

def create_sql(self, sql_dialect: SplinkDialect) -> str:
col_l, col_r = self.input_column(sql_dialect).names_l_r()
Expand Down

0 comments on commit 30ad12f

Please sign in to comment.