Convert all InputColumn methods that take no arguments to properties #1730

Merged (6 commits) on Nov 14, 2023
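The diff applies one mechanical pattern throughout: every InputColumn accessor that took no arguments becomes a read-only @property, so call sites drop their parentheses. A minimal sketch of the before/after, using a simplified stand-in class rather than the real one (which renders names through sqlglot expression trees):

    # Simplified illustration of the pattern; not the real InputColumn.
    class InputColumnBefore:
        def __init__(self, input_name: str):
            self.input_name = input_name

        def name_l(self) -> str:  # old style: called as col.name_l()
            return f"{self.input_name}_l"

    class InputColumnAfter:
        def __init__(self, input_name: str):
            self.input_name = input_name

        @property
        def name_l(self) -> str:  # new style: accessed as col.name_l
            return f"{self.input_name}_l"

    assert InputColumnBefore("city").name_l() == InputColumnAfter("city").name_l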
2 changes: 1 addition & 1 deletion splink/blocking.py
@@ -200,7 +200,7 @@ def _sql_gen_where_condition(link_type, unique_id_cols):
source_dataset_col = unique_id_cols[0]
where_condition = (
f"where {id_expr_l} < {id_expr_r} "
f"and l.{source_dataset_col.name()} != r.{source_dataset_col.name()}"
f"and l.{source_dataset_col.name} != r.{source_dataset_col.name}"
)

return where_condition
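For orientation, the where-condition built here would render roughly as follows; the id expressions and column name are hypothetical, not taken from the diff:

    # Hypothetical values for illustration only.
    id_expr_l, id_expr_r = "l.unique_id", "r.unique_id"
    source_dataset_name = "source_dataset"  # what source_dataset_col.name now yields

    where_condition = (
        f"where {id_expr_l} < {id_expr_r} "
        f"and l.{source_dataset_name} != r.{source_dataset_name}"
    )
    # -> where l.unique_id < r.unique_id and l.source_dataset != r.source_dataset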
2 changes: 1 addition & 1 deletion splink/cluster_metrics.py
@@ -24,7 +24,7 @@ def _size_density_sql(
clusters_table = df_clustered.physical_name

input_col = InputColumn(_unique_id_col)
unique_id_col_l = input_col.name_l()
unique_id_col_l = input_col.name_l

sqls = []
sql = f"""
16 changes: 8 additions & 8 deletions splink/comparison.py
@@ -210,14 +210,14 @@ def _columns_to_select_for_comparison_vector_values(self):
output_cols = []
for col in input_cols:
if self._settings_obj._retain_matching_columns:
output_cols.extend(col.names_l_r())
output_cols.extend(col.names_l_r)

output_cols.append(self._case_statement)

for cl in self.comparison_levels:
if cl._has_tf_adjustments:
col = cl._tf_adjustment_input_column
output_cols.extend(col.tf_name_l_r())
output_cols.extend(col.tf_name_l_r)

return dedupe_preserving_order(output_cols)

@@ -230,7 +230,7 @@ def _columns_to_select_for_bayes_factor_parts(self):
output_cols = []
for col in input_cols:
if self._settings_obj._retain_matching_columns:
output_cols.extend(col.names_l_r())
output_cols.extend(col.names_l_r)

output_cols.append(self._gamma_column_name)

Expand All @@ -240,7 +240,7 @@ def _columns_to_select_for_bayes_factor_parts(self):
and self._settings_obj._retain_intermediate_calculation_columns
):
col = cl._tf_adjustment_input_column
output_cols.extend(col.tf_name_l_r())
output_cols.extend(col.tf_name_l_r)

# Bayes factor case when statement
sqls = [cl._bayes_factor_sql for cl in self.comparison_levels]
@@ -268,7 +268,7 @@ def _columns_to_select_for_predict(self):
output_cols = []
for col in input_cols:
if self._settings_obj._retain_matching_columns:
output_cols.extend(col.names_l_r())
output_cols.extend(col.names_l_r)

if (
self._settings_obj._training_mode
Expand All @@ -282,7 +282,7 @@ def _columns_to_select_for_predict(self):
and self._settings_obj._retain_intermediate_calculation_columns
):
col = cl._tf_adjustment_input_column
output_cols.extend(col.tf_name_l_r())
output_cols.extend(col.tf_name_l_r)

for _col in input_cols:
if self._settings_obj._retain_intermediate_calculation_columns:
@@ -445,7 +445,7 @@ def _comparison_level_description_list(self):
@property
def _human_readable_description_succinct(self):
input_cols = join_list_with_commas_final_and(
[c.name() for c in self._input_columns_used_by_case_statement]
[c.name for c in self._input_columns_used_by_case_statement]
)

comp_levels = self._comparison_level_description_list
@@ -463,7 +463,7 @@ def _human_readable_description_succinct(self):
@property
def human_readable_description(self):
input_cols = join_list_with_commas_final_and(
[c.name() for c in self._input_columns_used_by_case_statement]
[c.name for c in self._input_columns_used_by_case_statement]
)

comp_levels = self._comparison_level_description_list
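The list-valued properties used above pair up the left/right variants of a column name. Assuming a column called first_name and splink's default tf_ term-frequency prefix (an assumption, not verified here), the values would plausibly be:

    from splink.input_column import InputColumn

    col = InputColumn("first_name")
    col.names_l_r    # expected: ["first_name_l", "first_name_r"]
    col.tf_name_l_r  # expected: ["tf_first_name_l", "tf_first_name_r"]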
18 changes: 6 additions & 12 deletions splink/comparison_level.py
@@ -202,7 +202,7 @@ def _tf_adjustment_input_column(self):
def _tf_adjustment_input_column_name(self):
input_column = self._tf_adjustment_input_column
if input_column:
return input_column.unquote().name()
return input_column.unquote().name

@property
def _has_comparison(self):
@@ -465,11 +465,9 @@ def _columns_to_select_for_blocking(self):
cols = self._input_columns_used_by_sql_condition

for c in cols:
output_cols.extend(c.l_r_names_as_l_r())
output_cols.extend(c.l_r_names_as_l_r)
if self._tf_adjustment_input_column:
output_cols.extend(
self._tf_adjustment_input_column.l_r_tf_names_as_l_r()
)
output_cols.extend(self._tf_adjustment_input_column.l_r_tf_names_as_l_r)

return dedupe_preserving_order(output_cols)

@@ -577,12 +575,8 @@ def _tf_adjustment_sql(self):
else:
tf_adj_col = self._tf_adjustment_input_column

coalesce_l_r = (
f"coalesce({tf_adj_col.tf_name_l()}, {tf_adj_col.tf_name_r()})"
)
coalesce_r_l = (
f"coalesce({tf_adj_col.tf_name_r()}, {tf_adj_col.tf_name_l()})"
)
coalesce_l_r = f"coalesce({tf_adj_col.tf_name_l}, {tf_adj_col.tf_name_r})"
coalesce_r_l = f"coalesce({tf_adj_col.tf_name_r}, {tf_adj_col.tf_name_l})"

tf_adjustment_exists = f"{coalesce_l_r} is not null"
u_prob_exact_match = self._u_probability_corresponding_to_exact_match
@@ -730,7 +724,7 @@ def _human_readable_succinct(self):
@property
def human_readable_description(self):
input_cols = join_list_with_commas_final_and(
[c.name() for c in self._input_columns_used_by_sql_condition]
[c.name for c in self._input_columns_used_by_sql_condition]
)
desc = (
f"Comparison level: {self.label_for_charts} of {input_cols}\n"
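The coalesce pair in _tf_adjustment_sql guards against a term-frequency value being present on only one side of the record pair. A sketch of the strings it would produce for a hypothetical column email, again assuming the default tf_ prefix:

    # Illustrative rendering for a made-up column "email".
    tf_name_l, tf_name_r = "tf_email_l", "tf_email_r"

    coalesce_l_r = f"coalesce({tf_name_l}, {tf_name_r})"
    coalesce_r_l = f"coalesce({tf_name_r}, {tf_name_l})"
    tf_adjustment_exists = f"{coalesce_l_r} is not null"
    # -> coalesce(tf_email_l, tf_email_r) is not null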
26 changes: 13 additions & 13 deletions splink/comparison_level_library.py
@@ -98,7 +98,7 @@ def __init__(
valid_string_pattern = valid_string_regex

col = InputColumn(col_name, sql_dialect=self._sql_dialect)
col_name_l, col_name_r = col.name_l(), col.name_r()
col_name_l, col_name_r = col.name_l, col.name_r

if invalid_dates_as_null:
col_name_l = self._valid_date_function(col_name_l, valid_string_pattern)
@@ -231,7 +231,7 @@ def __init__(
else:
label_suffix = ""

col_name_l, col_name_r = col.name_l(), col.name_r()
col_name_l, col_name_r = col.name_l, col.name_r

if set_to_lowercase:
col_name_l = f"lower({col_name_l})"
@@ -395,7 +395,7 @@ def __init__(
else:
operator = "<="

col_name_l, col_name_r = col.name_l(), col.name_r()
col_name_l, col_name_r = col.name_l, col.name_r

if set_to_lowercase:
col_name_l = f"lower({col_name_l})"
@@ -938,8 +938,8 @@ def __init__(
col_1 = InputColumn(col_name_1, sql_dialect=self._sql_dialect)
col_2 = InputColumn(col_name_2, sql_dialect=self._sql_dialect)

col_1_l, col_1_r = col_1.name_l(), col_1.name_r()
col_2_l, col_2_r = col_2.name_l(), col_2.name_r()
col_1_l, col_1_r = col_1.name_l, col_1.name_r
col_2_l, col_2_r = col_2.name_l, col_2.name_r

if set_to_lowercase:
col_1_l = f"lower({col_1_l})"
@@ -1030,8 +1030,8 @@ def __init__(

lat = InputColumn(lat_col, sql_dialect=self._sql_dialect)
long = InputColumn(long_col, sql_dialect=self._sql_dialect)
lat_l, lat_r = lat.names_l_r()
long_l, long_r = long.names_l_r()
lat_l, lat_r = lat.names_l_r
long_l, long_r = long.names_l_r

distance_km_sql = f"""
{great_circle_distance_km_sql(lat_l, lat_r, long_l, long_r)} <= {km_threshold}
@@ -1108,11 +1108,11 @@ def __init__(
"""
col = InputColumn(col_name, sql_dialect=self._sql_dialect)

s = f"""(abs({col.name_l()} - {col.name_r()})/
s = f"""(abs({col.name_l} - {col.name_r})/
(case
when {col.name_r()} > {col.name_l()}
then {col.name_r()}
else {col.name_l()}
when {col.name_r} > {col.name_l}
then {col.name_r}
else {col.name_l}
end))
< {percentage_distance_threshold}"""

@@ -1178,7 +1178,7 @@ def __init__(
col = InputColumn(col_name, sql_dialect=self._sql_dialect)

size_array_intersection = (
f"{self._size_array_intersect_function(col.name_l(), col.name_r())}"
f"{self._size_array_intersect_function(col.name_l, col.name_r)}"
)
sql = f"{size_array_intersection} >= {min_intersection}"

@@ -1359,7 +1359,7 @@ def __init__(
"""

date = InputColumn(date_col, sql_dialect=self._sql_dialect)
date_l, date_r = date.names_l_r()
date_l, date_r = date.names_l_r

datediff_sql = self._datediff_function(
date_l,
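One migration hazard worth flagging: code still using the old call syntax now calls a string. A quick illustration of the failure mode (the column name is made up):

    from splink.input_column import InputColumn

    col = InputColumn("city")
    print(col.name_l)  # property access -> "city_l"

    try:
        col.name_l()   # pre-#1730 call syntax
    except TypeError as err:
        print(err)     # 'str' object is not callable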
6 changes: 3 additions & 3 deletions splink/find_matches_to_new_records.py
@@ -11,7 +11,7 @@ def add_unique_id_and_source_dataset_cols_if_needed(
linker: "Linker", new_records_df: "SplinkDataFrame"
):
cols = new_records_df.columns
cols = [c.unquote().name() for c in cols]
cols = [c.unquote().name for c in cols]

# Add source dataset column to new records if required and not exists
sds_sel_sql = ""
@@ -21,15 +21,15 @@ def add_unique_id_and_source_dataset_cols_if_needed(
# TODO: Shouldn't be necessary but the source dataset properties on settings
# are currently broken
sds_col = InputColumn(sds_col, linker._settings_obj)
sds_col = sds_col.unquote().name()
sds_col = sds_col.unquote().name
if sds_col not in cols:
sds_sel_sql = f", 'new_record' as {sds_col}"

# Add unique_id column to new records if not exists
uid_sel_sql = ""
uid_col = linker._settings_obj._unique_id_column_name
uid_col = InputColumn(uid_col, linker._settings_obj)
uid_col = uid_col.unquote().name()
uid_col = uid_col.unquote().name
if uid_col not in cols:
uid_sel_sql = f", 'no_id_provided' as {uid_col}"

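Note the asymmetry that survives the change: unquote() keeps its parentheses because it returns a new InputColumn rather than reading state, while the terminal name access becomes a property. Chained, that reads roughly as:

    from splink.input_column import InputColumn

    # unquote() stays a method (returns a fresh InputColumn);
    # .name is now a property on the result. Column name is illustrative.
    uid_col = InputColumn("unique_id")
    bare_name = uid_col.unquote().name  # e.g. "unique_id"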
31 changes: 23 additions & 8 deletions splink/input_column.py
@@ -168,74 +168,89 @@ def tf_prefix(self) -> str:
"_tf_prefix", "term_frequency_adjustment_column_prefix"
)

@property
def name(self) -> str:
return self.input_name_as_tree.sql(dialect=self._sql_dialect)

@property
def name_l(self) -> str:
return add_suffix(self.input_name_as_tree, suffix="_l").sql(
dialect=self._sql_dialect
)

@property
def name_r(self) -> str:
return add_suffix(self.input_name_as_tree, suffix="_r").sql(
dialect=self._sql_dialect
)

@property
def names_l_r(self) -> list[str]:
return [self.name_l(), self.name_r()]
return [self.name_l, self.name_r]

@property
def l_name_as_l(self) -> str:
name_with_l_table = add_table(self.input_name_as_tree, "l").sql(
dialect=self._sql_dialect
)
return f"{name_with_l_table} as {self.name_l()}"
return f"{name_with_l_table} as {self.name_l}"

@property
def r_name_as_r(self) -> str:
name_with_r_table = add_table(self.input_name_as_tree, "r").sql(
dialect=self._sql_dialect
)
return f"{name_with_r_table} as {self.name_r()}"
return f"{name_with_r_table} as {self.name_r}"

@property
def l_r_names_as_l_r(self) -> list[str]:
return [self.l_name_as_l(), self.r_name_as_r()]
return [self.l_name_as_l, self.r_name_as_r]

@property
def bf_name(self) -> str:
return add_prefix(self.input_name_as_tree, prefix=self.bf_prefix).sql(
dialect=self._sql_dialect
)

@property
def tf_name(self) -> str:
return add_prefix(self.input_name_as_tree, prefix=self.tf_prefix).sql(
dialect=self._sql_dialect
)

@property
def tf_name_l(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
return add_suffix(tree, suffix="_l").sql(dialect=self._sql_dialect)

@property
def tf_name_r(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
return add_suffix(tree, suffix="_r").sql(dialect=self._sql_dialect)

@property
def tf_name_l_r(self) -> list[str]:
return [self.tf_name_l(), self.tf_name_r()]
return [self.tf_name_l, self.tf_name_r]

@property
def l_tf_name_as_l(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
tf_name_with_l_table = add_table(tree, tablename="l").sql(
dialect=self._sql_dialect
)
return f"{tf_name_with_l_table} as {self.tf_name_l()}"
return f"{tf_name_with_l_table} as {self.tf_name_l}"

@property
def r_tf_name_as_r(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
tf_name_with_r_table = add_table(tree, tablename="r").sql(
dialect=self._sql_dialect
)
return f"{tf_name_with_r_table} as {self.tf_name_r()}"
return f"{tf_name_with_r_table} as {self.tf_name_r}"

@property
def l_r_tf_names_as_l_r(self) -> list[str]:
return [self.l_tf_name_as_l(), self.r_tf_name_as_r()]
return [self.l_tf_name_as_l, self.r_tf_name_as_r]

def _quote_if_sql_keyword(self, name: str) -> str:
if name not in {"group", "index"}:
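Taken together, InputColumn now exposes a uniform attribute-style surface. A speculative usage sketch; the expected strings assume an unquoted column surname, the default tf_ prefix, and no dialect-specific quoting, none of which is verified here:

    from splink.input_column import InputColumn

    col = InputColumn("surname")
    col.name         # "surname"
    col.name_l       # "surname_l"
    col.names_l_r    # ["surname_l", "surname_r"]
    col.tf_name      # "tf_surname"
    col.l_name_as_l  # e.g. "l.surname as surname_l" (quoting is dialect-dependent)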
2 changes: 1 addition & 1 deletion splink/linker.py
@@ -260,7 +260,7 @@ def _input_columns(
# sort it for consistent ordering, and give each frame's
# columns as a tuple so we can hash it
column_names_by_input_df = [
tuple(sorted([col.name() for col in input_df.columns]))
tuple(sorted([col.name for col in input_df.columns]))
for input_df in input_dfs
]
# check that the set of input columns is the same for each frame,
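The surrounding check hashes each frame's sorted column names so that multiple input frames can be validated as sharing a schema. The idea in isolation, with invented frame contents:

    # Sorted tuples are hashable, so a set collapses frames with
    # identical columns regardless of column order.
    frames = [["surname", "city"], ["city", "surname"], ["surname", "dob"]]

    column_names_by_input_df = [tuple(sorted(cols)) for cols in frames]
    if len(set(column_names_by_input_df)) > 1:
        print("input frames do not share an identical set of columns")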
2 changes: 1 addition & 1 deletion splink/lower_id_on_lhs.py
@@ -66,7 +66,7 @@ def lower_id_to_left_hand_side(
""" # noqa

cols = df.columns
cols = [c.unquote().name() for c in cols]
cols = [c.unquote().name for c in cols]

l_cols = [c for c in cols if c.endswith("_l")]
r_cols = [c for c in cols if c.endswith("_r")]
4 changes: 2 additions & 2 deletions splink/missingness.py
@@ -8,8 +8,8 @@ def missingness_sqls(columns, input_tablename):

selects = [
col_template.format(
col_name_escaped=col.name(),
col_name=col.unquote().name(),
col_name_escaped=col.name,
col_name=col.unquote().name,
input_tablename=input_tablename,
)
for col in columns
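The template here needs both forms of the name: the escaped form (col.name) to reference the column safely in SQL, and the unquoted form (col.unquote().name) to serve as a plain label. A hedged illustration for a column that happens to be a SQL keyword; the quote characters and template shape are assumptions, not the real col_template:

    # Hypothetical column literally named "group".
    template = "select count({col_name_escaped}) as cnt, '{col_name}' as col from {t}"
    sql = template.format(col_name_escaped='"group"', col_name="group", t="df")
    # -> select count("group") as cnt, 'group' as col from df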
3 changes: 1 addition & 2 deletions splink/profile_data.py
@@ -232,7 +232,7 @@ def profile_columns(linker, column_expressions=None, top_n=10, bottom_n=10):
"""

if not column_expressions:
column_expressions = [col.name() for col in linker._input_columns]
column_expressions = [col.name for col in linker._input_columns]

df_concat = linker._initialise_df_concat()

@@ -297,7 +297,6 @@ def profile_columns(linker, column_expressions=None, top_n=10, bottom_n=10):
inner_charts.append(inner_chart)

if inner_charts != []:

outer_spec = deepcopy(_outer_chart_spec_freq)
outer_spec["vconcat"] = inner_charts
