Merge pull request #1730 from moj-analytical-services/1688-inputcolumn-methods-could-be-properties

Convert all InputColumn methods that take no arguments to properties
RobinL authored Nov 14, 2023
2 parents 4cb5438 + 593a93b commit 4374aa1
Showing 19 changed files with 119 additions and 111 deletions.
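
The pattern applied throughout is the standard Python conversion of a zero-argument method to a read-only property. A minimal sketch of the idea (an illustrative stand-in, not the verbatim splink class):

    class InputColumn:
        def __init__(self, input_name: str):
            self.input_name = input_name

        # Before this commit: a zero-argument method, invoked as col.name_l()
        # def name_l(self) -> str:
        #     return f"{self.input_name}_l"

        # After: a read-only property, accessed as col.name_l
        @property
        def name_l(self) -> str:
            return f"{self.input_name}_l"

Every call site accordingly drops its parentheses. A stale call such as col.name_l() now raises TypeError ("'str' object is not callable"), so any call site missed in the rename fails loudly rather than silently.
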
2 changes: 1 addition & 1 deletion splink/blocking.py
@@ -200,7 +200,7 @@ def _sql_gen_where_condition(link_type, unique_id_cols):
         source_dataset_col = unique_id_cols[0]
         where_condition = (
             f"where {id_expr_l} < {id_expr_r} "
-            f"and l.{source_dataset_col.name()} != r.{source_dataset_col.name()}"
+            f"and l.{source_dataset_col.name} != r.{source_dataset_col.name}"
         )

         return where_condition
2 changes: 1 addition & 1 deletion splink/cluster_metrics.py
@@ -24,7 +24,7 @@ def _size_density_sql(
     clusters_table = df_clustered.physical_name

     input_col = InputColumn(_unique_id_col)
-    unique_id_col_l = input_col.name_l()
+    unique_id_col_l = input_col.name_l

     sqls = []
     sql = f"""
16 changes: 8 additions & 8 deletions splink/comparison.py
@@ -210,14 +210,14 @@ def _columns_to_select_for_comparison_vector_values(self):
         output_cols = []
         for col in input_cols:
             if self._settings_obj._retain_matching_columns:
-                output_cols.extend(col.names_l_r())
+                output_cols.extend(col.names_l_r)

         output_cols.append(self._case_statement)

         for cl in self.comparison_levels:
             if cl._has_tf_adjustments:
                 col = cl._tf_adjustment_input_column
-                output_cols.extend(col.tf_name_l_r())
+                output_cols.extend(col.tf_name_l_r)

         return dedupe_preserving_order(output_cols)

@@ -230,7 +230,7 @@ def _columns_to_select_for_bayes_factor_parts(self):
         output_cols = []
         for col in input_cols:
             if self._settings_obj._retain_matching_columns:
-                output_cols.extend(col.names_l_r())
+                output_cols.extend(col.names_l_r)

         output_cols.append(self._gamma_column_name)

@@ -240,7 +240,7 @@ def _columns_to_select_for_bayes_factor_parts(self):
                 and self._settings_obj._retain_intermediate_calculation_columns
             ):
                 col = cl._tf_adjustment_input_column
-                output_cols.extend(col.tf_name_l_r())
+                output_cols.extend(col.tf_name_l_r)

         # Bayes factor case when statement
         sqls = [cl._bayes_factor_sql for cl in self.comparison_levels]
@@ -268,7 +268,7 @@ def _columns_to_select_for_predict(self):
         output_cols = []
         for col in input_cols:
             if self._settings_obj._retain_matching_columns:
-                output_cols.extend(col.names_l_r())
+                output_cols.extend(col.names_l_r)

         if (
             self._settings_obj._training_mode
@@ -282,7 +282,7 @@ def _columns_to_select_for_predict(self):
                 and self._settings_obj._retain_intermediate_calculation_columns
            ):
                 col = cl._tf_adjustment_input_column
-                output_cols.extend(col.tf_name_l_r())
+                output_cols.extend(col.tf_name_l_r)

         for _col in input_cols:
             if self._settings_obj._retain_intermediate_calculation_columns:
@@ -445,7 +445,7 @@ def _comparison_level_description_list(self):
     @property
     def _human_readable_description_succinct(self):
         input_cols = join_list_with_commas_final_and(
-            [c.name() for c in self._input_columns_used_by_case_statement]
+            [c.name for c in self._input_columns_used_by_case_statement]
         )

         comp_levels = self._comparison_level_description_list
@@ -463,7 +463,7 @@ def _human_readable_description_succinct(self):
     @property
     def human_readable_description(self):
         input_cols = join_list_with_commas_final_and(
-            [c.name() for c in self._input_columns_used_by_case_statement]
+            [c.name for c in self._input_columns_used_by_case_statement]
         )

         comp_levels = self._comparison_level_description_list
18 changes: 6 additions & 12 deletions splink/comparison_level.py
@@ -202,7 +202,7 @@ def _tf_adjustment_input_column(self):
     def _tf_adjustment_input_column_name(self):
         input_column = self._tf_adjustment_input_column
         if input_column:
-            return input_column.unquote().name()
+            return input_column.unquote().name

     @property
     def _has_comparison(self):
@@ -465,11 +465,9 @@ def _columns_to_select_for_blocking(self):
         cols = self._input_columns_used_by_sql_condition

         for c in cols:
-            output_cols.extend(c.l_r_names_as_l_r())
+            output_cols.extend(c.l_r_names_as_l_r)
         if self._tf_adjustment_input_column:
-            output_cols.extend(
-                self._tf_adjustment_input_column.l_r_tf_names_as_l_r()
-            )
+            output_cols.extend(self._tf_adjustment_input_column.l_r_tf_names_as_l_r)

         return dedupe_preserving_order(output_cols)

@@ -577,12 +575,8 @@ def _tf_adjustment_sql(self):
         else:
             tf_adj_col = self._tf_adjustment_input_column

-        coalesce_l_r = (
-            f"coalesce({tf_adj_col.tf_name_l()}, {tf_adj_col.tf_name_r()})"
-        )
-        coalesce_r_l = (
-            f"coalesce({tf_adj_col.tf_name_r()}, {tf_adj_col.tf_name_l()})"
-        )
+        coalesce_l_r = f"coalesce({tf_adj_col.tf_name_l}, {tf_adj_col.tf_name_r})"
+        coalesce_r_l = f"coalesce({tf_adj_col.tf_name_r}, {tf_adj_col.tf_name_l})"

         tf_adjustment_exists = f"{coalesce_l_r} is not null"
         u_prob_exact_match = self._u_probability_corresponding_to_exact_match
@@ -730,7 +724,7 @@ def _human_readable_succinct(self):
     @property
     def human_readable_description(self):
         input_cols = join_list_with_commas_final_and(
-            [c.name() for c in self._input_columns_used_by_sql_condition]
+            [c.name for c in self._input_columns_used_by_sql_condition]
         )
         desc = (
             f"Comparison level: {self.label_for_charts} of {input_cols}\n"
26 changes: 13 additions & 13 deletions splink/comparison_level_library.py
@@ -98,7 +98,7 @@ def __init__(
         valid_string_pattern = valid_string_regex

         col = InputColumn(col_name, sql_dialect=self._sql_dialect)
-        col_name_l, col_name_r = col.name_l(), col.name_r()
+        col_name_l, col_name_r = col.name_l, col.name_r

         if invalid_dates_as_null:
             col_name_l = self._valid_date_function(col_name_l, valid_string_pattern)
@@ -231,7 +231,7 @@ def __init__(
         else:
             label_suffix = ""

-        col_name_l, col_name_r = col.name_l(), col.name_r()
+        col_name_l, col_name_r = col.name_l, col.name_r

         if set_to_lowercase:
             col_name_l = f"lower({col_name_l})"
@@ -395,7 +395,7 @@ def __init__(
         else:
             operator = "<="

-        col_name_l, col_name_r = col.name_l(), col.name_r()
+        col_name_l, col_name_r = col.name_l, col.name_r

         if set_to_lowercase:
             col_name_l = f"lower({col_name_l})"
@@ -938,8 +938,8 @@ def __init__(
         col_1 = InputColumn(col_name_1, sql_dialect=self._sql_dialect)
         col_2 = InputColumn(col_name_2, sql_dialect=self._sql_dialect)

-        col_1_l, col_1_r = col_1.name_l(), col_1.name_r()
-        col_2_l, col_2_r = col_2.name_l(), col_2.name_r()
+        col_1_l, col_1_r = col_1.name_l, col_1.name_r
+        col_2_l, col_2_r = col_2.name_l, col_2.name_r

         if set_to_lowercase:
             col_1_l = f"lower({col_1_l})"
@@ -1030,8 +1030,8 @@ def __init__(

         lat = InputColumn(lat_col, sql_dialect=self._sql_dialect)
         long = InputColumn(long_col, sql_dialect=self._sql_dialect)
-        lat_l, lat_r = lat.names_l_r()
-        long_l, long_r = long.names_l_r()
+        lat_l, lat_r = lat.names_l_r
+        long_l, long_r = long.names_l_r

         distance_km_sql = f"""
         {great_circle_distance_km_sql(lat_l, lat_r, long_l, long_r)} <= {km_threshold}
@@ -1108,11 +1108,11 @@ def __init__(
         """
         col = InputColumn(col_name, sql_dialect=self._sql_dialect)

-        s = f"""(abs({col.name_l()} - {col.name_r()})/
+        s = f"""(abs({col.name_l} - {col.name_r})/
         (case
-        when {col.name_r()} > {col.name_l()}
-        then {col.name_r()}
-        else {col.name_l()}
+        when {col.name_r} > {col.name_l}
+        then {col.name_r}
+        else {col.name_l}
         end))
         < {percentage_distance_threshold}"""

@@ -1178,7 +1178,7 @@ def __init__(
         col = InputColumn(col_name, sql_dialect=self._sql_dialect)

         size_array_intersection = (
-            f"{self._size_array_intersect_function(col.name_l(), col.name_r())}"
+            f"{self._size_array_intersect_function(col.name_l, col.name_r)}"
         )
         sql = f"{size_array_intersection} >= {min_intersection}"

@@ -1359,7 +1359,7 @@ def __init__(
         """

         date = InputColumn(date_col, sql_dialect=self._sql_dialect)
-        date_l, date_r = date.names_l_r()
+        date_l, date_r = date.names_l_r

         datediff_sql = self._datediff_function(
             date_l,
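
For orientation, the percentage-difference level in the @@ -1108 hunk above evaluates abs(x_l - x_r) / max(x_l, x_r) < percentage_distance_threshold. For example, values 90 and 100 give 10 / 100 = 0.1, which satisfies the condition only for thresholds greater than 0.1.
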
6 changes: 3 additions & 3 deletions splink/find_matches_to_new_records.py
@@ -11,7 +11,7 @@ def add_unique_id_and_source_dataset_cols_if_needed(
     linker: "Linker", new_records_df: "SplinkDataFrame"
 ):
     cols = new_records_df.columns
-    cols = [c.unquote().name() for c in cols]
+    cols = [c.unquote().name for c in cols]

     # Add source dataset column to new records if required and not exists
     sds_sel_sql = ""
@@ -21,15 +21,15 @@ def add_unique_id_and_source_dataset_cols_if_needed(
         # TODO: Shouldn't be necessary but the source dataset properties on settings
         # are currently broken
         sds_col = InputColumn(sds_col, linker._settings_obj)
-        sds_col = sds_col.unquote().name()
+        sds_col = sds_col.unquote().name
         if sds_col not in cols:
             sds_sel_sql = f", 'new_record' as {sds_col}"

     # Add unique_id column to new records if not exists
     uid_sel_sql = ""
     uid_col = linker._settings_obj._unique_id_column_name
     uid_col = InputColumn(uid_col, linker._settings_obj)
-    uid_col = uid_col.unquote().name()
+    uid_col = uid_col.unquote().name
     if uid_col not in cols:
         uid_sel_sql = f", 'no_id_provided' as {uid_col}"
31 changes: 23 additions & 8 deletions splink/input_column.py
@@ -168,74 +168,89 @@ def tf_prefix(self) -> str:
             "_tf_prefix", "term_frequency_adjustment_column_prefix"
         )

+    @property
     def name(self) -> str:
         return self.input_name_as_tree.sql(dialect=self._sql_dialect)

+    @property
     def name_l(self) -> str:
         return add_suffix(self.input_name_as_tree, suffix="_l").sql(
             dialect=self._sql_dialect
         )

+    @property
     def name_r(self) -> str:
         return add_suffix(self.input_name_as_tree, suffix="_r").sql(
             dialect=self._sql_dialect
         )

+    @property
     def names_l_r(self) -> list[str]:
-        return [self.name_l(), self.name_r()]
+        return [self.name_l, self.name_r]

+    @property
     def l_name_as_l(self) -> str:
         name_with_l_table = add_table(self.input_name_as_tree, "l").sql(
             dialect=self._sql_dialect
         )
-        return f"{name_with_l_table} as {self.name_l()}"
+        return f"{name_with_l_table} as {self.name_l}"

+    @property
     def r_name_as_r(self) -> str:
         name_with_r_table = add_table(self.input_name_as_tree, "r").sql(
             dialect=self._sql_dialect
         )
-        return f"{name_with_r_table} as {self.name_r()}"
+        return f"{name_with_r_table} as {self.name_r}"

+    @property
     def l_r_names_as_l_r(self) -> list[str]:
-        return [self.l_name_as_l(), self.r_name_as_r()]
+        return [self.l_name_as_l, self.r_name_as_r]

+    @property
     def bf_name(self) -> str:
         return add_prefix(self.input_name_as_tree, prefix=self.bf_prefix).sql(
             dialect=self._sql_dialect
         )

+    @property
     def tf_name(self) -> str:
         return add_prefix(self.input_name_as_tree, prefix=self.tf_prefix).sql(
             dialect=self._sql_dialect
         )

+    @property
     def tf_name_l(self) -> str:
         tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
         return add_suffix(tree, suffix="_l").sql(dialect=self._sql_dialect)

+    @property
     def tf_name_r(self) -> str:
         tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
         return add_suffix(tree, suffix="_r").sql(dialect=self._sql_dialect)

+    @property
     def tf_name_l_r(self) -> list[str]:
-        return [self.tf_name_l(), self.tf_name_r()]
+        return [self.tf_name_l, self.tf_name_r]

+    @property
     def l_tf_name_as_l(self) -> str:
         tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
         tf_name_with_l_table = add_table(tree, tablename="l").sql(
             dialect=self._sql_dialect
         )
-        return f"{tf_name_with_l_table} as {self.tf_name_l()}"
+        return f"{tf_name_with_l_table} as {self.tf_name_l}"

+    @property
     def r_tf_name_as_r(self) -> str:
         tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
         tf_name_with_r_table = add_table(tree, tablename="r").sql(
             dialect=self._sql_dialect
         )
-        return f"{tf_name_with_r_table} as {self.tf_name_r()}"
+        return f"{tf_name_with_r_table} as {self.tf_name_r}"

+    @property
     def l_r_tf_names_as_l_r(self) -> list[str]:
-        return [self.l_tf_name_as_l(), self.r_tf_name_as_r()]
+        return [self.l_tf_name_as_l, self.r_tf_name_as_r]

     def _quote_if_sql_keyword(self, name: str) -> str:
         if name not in {"group", "index"}:
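
With this change, the derived column names are read as plain attributes. A usage sketch (outputs are illustrative, inferred from the suffix/prefix logic above; exact quoting depends on the configured SQL dialect, and the default term-frequency prefix is assumed here to be "tf_"):

    col = InputColumn("first_name")
    col.name         # "first_name"
    col.name_l       # "first_name_l"
    col.names_l_r    # ["first_name_l", "first_name_r"]
    col.l_name_as_l  # "l.first_name as first_name_l"
    col.tf_name_l    # "tf_first_name_l"
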
2 changes: 1 addition & 1 deletion splink/linker.py
@@ -260,7 +260,7 @@ def _input_columns(
         # sort it for consistent ordering, and give each frame's
         # columns as a tuple so we can hash it
         column_names_by_input_df = [
-            tuple(sorted([col.name() for col in input_df.columns]))
+            tuple(sorted([col.name for col in input_df.columns]))
             for input_df in input_dfs
         ]
         # check that the set of input columns is the same for each frame,
2 changes: 1 addition & 1 deletion splink/lower_id_on_lhs.py
@@ -66,7 +66,7 @@ def lower_id_to_left_hand_side(
     """  # noqa

     cols = df.columns
-    cols = [c.unquote().name() for c in cols]
+    cols = [c.unquote().name for c in cols]

     l_cols = [c for c in cols if c.endswith("_l")]
     r_cols = [c for c in cols if c.endswith("_r")]
4 changes: 2 additions & 2 deletions splink/missingness.py
@@ -8,8 +8,8 @@ def missingness_sqls(columns, input_tablename):

     selects = [
         col_template.format(
-            col_name_escaped=col.name(),
-            col_name=col.unquote().name(),
+            col_name_escaped=col.name,
+            col_name=col.unquote().name,
             input_tablename=input_tablename,
         )
         for col in columns
3 changes: 1 addition & 2 deletions splink/profile_data.py
@@ -232,7 +232,7 @@ def profile_columns(linker, column_expressions=None, top_n=10, bottom_n=10):
     """

     if not column_expressions:
-        column_expressions = [col.name() for col in linker._input_columns]
+        column_expressions = [col.name for col in linker._input_columns]

     df_concat = linker._initialise_df_concat()

@@ -297,7 +297,6 @@ def profile_columns(linker, column_expressions=None, top_n=10, bottom_n=10):
             inner_charts.append(inner_chart)

     if inner_charts != []:
-
         outer_spec = deepcopy(_outer_chart_spec_freq)
         outer_spec["vconcat"] = inner_charts
