Skip to content

Commit

Permalink
Merge pull request #1861 from moj-analytical-services/comparison-vali…
Browse files Browse the repository at this point in the history
…dation

Comparison validation
ADBond authored Jan 22, 2024
2 parents 9f63045 + f6c449d commit fe45feb
Showing 4 changed files with 82 additions and 15 deletions.
5 changes: 5 additions & 0 deletions splink/comparison_creator.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ def __init__(
name_reference: ColumnExpression.instantiate_if_str(column)
for name_reference, column in cols.items()
}
self._validate()

# many ComparisonCreators have a single column expression, so provide a
# convenience property for this case. Error if there are none or many
@@ -57,6 +58,10 @@ def col_expression(self) -> ColumnExpression:
) from None
return col_expression

def _validate(self) -> None:
# create levels - let them raise errors if there are issues
self.create_comparison_levels()

# TODO: property?
@abstractmethod
def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
67 changes: 53 additions & 14 deletions splink/comparison_level_library.py
Original file line number Diff line number Diff line change
@@ -36,22 +36,44 @@ def _translate_sql_string(
return tree.sql(dialect=to_sqlglot_dialect)


def validate_distance_threshold(
def validate_numeric_parameter(
lower_bound: Union[int, float],
upper_bound: Union[int, float],
distance_threshold: Union[int, float],
parameter_value: Union[int, float],
level_name: str,
parameter_name: str = "distance_threshold",
) -> Union[int, float]:
"""Check if a distance threshold falls between two bounds."""
if lower_bound <= distance_threshold <= upper_bound:
return distance_threshold
if not isinstance(parameter_value, (int, float)):
raise TypeError(
f"'{parameter_name}' must be numeric, but received type "
f"{type(parameter_value)}"
)
if lower_bound <= parameter_value <= upper_bound:
return parameter_value
else:
raise ValueError(
"'distance_threshold' must be between "
f"'{parameter_name}' must be between "
f"{lower_bound} and {upper_bound} for {level_name}"
)


def validate_categorical_parameter(
allowed_values: List[str],
parameter_value: str,
level_name: str,
parameter_name: str,
) -> Union[int, float]:
"""Check if a distance threshold falls between two bounds."""
if parameter_value in allowed_values:
return parameter_value
else:
comma_quote_separated_options = "', '".join(allowed_values)
raise ValueError(
f"'{parameter_name}' must be one of: " f"'{comma_quote_separated_options}'"
)


class NullLevel(ComparisonLevelCreator):
def __init__(
self,
@@ -301,10 +323,10 @@ def __init__(
"""

self.col_expression = ColumnExpression.instantiate_if_str(col_name)
self.distance_threshold = validate_distance_threshold(
self.distance_threshold = validate_numeric_parameter(
lower_bound=0,
upper_bound=1,
distance_threshold=distance_threshold,
parameter_value=distance_threshold,
level_name=self.__class__.__name__,
)

@@ -336,10 +358,10 @@ def __init__(
"""

self.col_expression = ColumnExpression.instantiate_if_str(col_name)
self.distance_threshold = validate_distance_threshold(
self.distance_threshold = validate_numeric_parameter(
lower_bound=0,
upper_bound=1,
distance_threshold=distance_threshold,
parameter_value=distance_threshold,
level_name=self.__class__.__name__,
)

@@ -371,10 +393,10 @@ def __init__(
"""

self.col_expression = ColumnExpression.instantiate_if_str(col_name)
self.distance_threshold = validate_distance_threshold(
self.distance_threshold = validate_numeric_parameter(
lower_bound=0,
upper_bound=1,
distance_threshold=distance_threshold,
parameter_value=distance_threshold,
level_name=self.__class__.__name__,
)

@@ -459,8 +481,19 @@ def __init__(
date_format (str): The format of the date string
"""
self.col_expression = ColumnExpression.instantiate_if_str(col_name)
self.date_threshold = date_threshold
self.date_metric = date_metric
self.date_threshold = validate_numeric_parameter(
lower_bound=0,
upper_bound=float("inf"),
parameter_value=date_threshold,
level_name=self.__class__.__name__,
parameter_name="date_threshold",
)
self.date_metric = validate_categorical_parameter(
allowed_values=["day", "month", "year"],
parameter_value=date_metric,
level_name=self.__class__.__name__,
parameter_name="date_metric",
)

@unsupported_splink_dialects(["sqlite"])
def create_sql(self, sql_dialect: SplinkDialect) -> str:
@@ -572,7 +605,13 @@ def __init__(self, col_name: str, min_intersection: int):
"""

self.col_expression = ColumnExpression.instantiate_if_str(col_name)
self.min_intersection = min_intersection
self.min_intersection = validate_numeric_parameter(
lower_bound=0,
upper_bound=float("inf"),
parameter_value=min_intersection,
level_name=self.__class__.__name__,
parameter_name="min_intersection",
)

@unsupported_splink_dialects(["sqlite"])
def create_sql(self, sql_dialect: SplinkDialect) -> str:
12 changes: 12 additions & 0 deletions splink/comparison_library.py
Original file line number Diff line number Diff line change
@@ -451,6 +451,18 @@ def __init__(
date_thresholds_as_iterable = ensure_is_iterable(date_thresholds)
self.date_thresholds = [*date_thresholds_as_iterable]

num_metrics = len(self.date_metrics)
num_thresholds = len(self.date_thresholds)
if num_thresholds == 0:
raise ValueError("`date_thresholds` must have at least one entry")
if num_metrics == 0:
raise ValueError("`date_metrics` must have at least one entry")
if num_metrics != num_thresholds:
raise ValueError(
"`date_thresholds` and `date_metrics` must have "
"the same number of entries"
)

self.cast_strings_to_dates = cast_strings_to_dates
self.date_format = date_format

13 changes: 12 additions & 1 deletion splink/comparison_template_library.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,18 @@ def __init__(
self.date_thresholds = [*date_thresholds_as_iterable]
date_metrics_as_iterable = ensure_is_iterable(datediff_metrics)
self.date_metrics = [*date_metrics_as_iterable]
# TODO: check lengths match!

num_metrics = len(self.date_metrics)
num_thresholds = len(self.date_thresholds)
if num_thresholds == 0:
raise ValueError("`date_thresholds` must have at least one entry")
if num_metrics == 0:
raise ValueError("`date_metrics` must have at least one entry")
if num_metrics != num_thresholds:
raise ValueError(
"`date_thresholds` and `date_metrics` must have "
"the same number of entries"
)

self.date_format = date_format
self.invalid_dates_as_null = invalid_dates_as_null

0 comments on commit fe45feb

Please sign in to comment.