1064 profile array elements #1397

Closed · wants to merge 26 commits

Changes from all commits (26 commits)
13 changes: 13 additions & 0 deletions splink/duckdb/linker.py
@@ -33,6 +33,19 @@ def columns(self) -> list[InputColumn]:
        col_strings = list(d.keys())
        return [InputColumn(c, sql_dialect="duckdb") for c in col_strings]

    @property
    def get_table_schema(self):
        sql = f"DESCRIBE {self.physical_name}"
        return self.linker._con.query(sql).to_df()

    def get_array_cols(self):
        schema = self.get_table_schema
        return [
            col
            for col, type in zip(schema.column_name, schema.column_type)
            if type.endswith("[]")
        ]

    def validate(self):
        pass

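For context, a brief hedged sketch of what this detection relies on: DuckDB's DESCRIBE reports list columns with a "[]" suffix on the type name, which is what get_array_cols() matches. The table and column names below are illustrative, not taken from the PR.

```py
import duckdb

con = duckdb.connect()
con.execute("CREATE TABLE df_concat (first_name VARCHAR, postcode_arr VARCHAR[])")

# DESCRIBE yields one row per column, with the DuckDB type in column_type;
# array (list) columns show up with a trailing "[]", e.g. VARCHAR[].
schema = con.query("DESCRIBE df_concat").to_df()
print(schema[["column_name", "column_type"]])
#     column_name column_type
# 0    first_name     VARCHAR
# 1  postcode_arr   VARCHAR[]

array_cols = [
    col
    for col, col_type in zip(schema.column_name, schema.column_type)
    if col_type.endswith("[]")
]
print(array_cols)  # ['postcode_arr']
```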
50 changes: 48 additions & 2 deletions splink/linker.py
@@ -2023,9 +2023,55 @@ def cluster_pairwise_predictions_at_threshold(
        return cc

    def profile_columns(
-        self, column_expressions: str | list[str], top_n=10, bottom_n=10
        self,
        column_expressions: str | list[str],
        top_n=10,
        bottom_n=10,
        cast_arrays_as_str=False,
Review comment (Contributor):
Ah, this is False by default too, so users are going to get spammed by the warning.

This won't be so much of an issue once this is implemented across the board, but it might be a little obnoxious if we don't apply it across the board for a while.

    ):
-        return profile_columns(self, column_expressions, top_n=top_n, bottom_n=bottom_n)
"""Generate three summary charts for each `column_expression`.

The purpose of these charts is to better understand the cardinality and skew of
each `column_expression` chosen for profiling. Further information can be found
on the Exploratory analysis
[tutorial](https://moj-analytical-services.github.io/splink/demos/02_
Exploratory_analysis.html#analyse-the-distribution-of-values-in-your-data)


Args:
column_expressions (str | list[str]): The columns to be profiled or
an sql expression
top_n (int): The number of bars in the chart for the most common values.
bottom_n (int): The number of bars in the chart for the least common values.
cast_arrays_as_str (bool): If False, any columns with arrays will be
unnested and the elements of the array will be profiled. If set
to True, the whole array will be profiled.

Returns:
altair_or_json: An altair line graph, a bar graph of the most common values
and a bar chart of the least common values, for each column profiled.

Examples:
```py
linker.profile_columns(["first_name", "surname", "substr(dob, 1,4)"],
top_n=10, bottom_n=5)

```

```py
linker.profile_columns([["first_name", "surname"], "email", "postcode_arr"],
cast_arrays_as_str=True)

```

"""
return profile_columns(
self,
column_expressions,
top_n=top_n,
bottom_n=bottom_n,
cast_arrays_as_str=cast_arrays_as_str,
)

    def _get_labels_tablename_from_input(
        self, labels_splinkdataframe_or_table_name: str | SplinkDataFrame
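As a quick usage sketch of the new argument (assuming a DuckDB-backed linker; the array column name is illustrative):

```py
# Default behaviour: array columns are unnested and the individual
# elements are profiled.
linker.profile_columns("postcode_arr")

# Opt out: profile each whole array as a single string value instead.
linker.profile_columns("postcode_arr", cast_arrays_as_str=True)
```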
82 changes: 62 additions & 20 deletions splink/profile_data.py
@@ -137,7 +137,9 @@ def _get_df_top_bottom_n(expressions, limit=20, value_order="desc"):
    return sql


-def _col_or_expr_frequencies_raw_data_sql(cols_or_exprs, table_name):
def _col_or_expr_frequencies_raw_data_sql(
    cols_or_exprs, array_cols, table_name, cast_arrays_as_str
):
    cols_or_exprs = ensure_is_list(cols_or_exprs)
    column_expressions = expressions_to_sql(cols_or_exprs)
    sqls = []
@@ -146,8 +148,12 @@ def _col_or_expr_frequencies_raw_data_sql(cols_or_exprs, table_name):

        # If the supplied column string is a list of columns to be concatenated,
        # add a quick clause to filter out any instances whereby either column contains
-        # a null value.
        # a null value. Also raise an error if the user tries to supply array columns.

        if isinstance(raw_expr, list):
            if any([expr in array_cols for expr in raw_expr]):
                raise ValueError("Arrays cannot be concatenated during profiling")

            null_exprs = [f"{c} is null" for c in raw_expr]
            null_exprs = " OR ".join(null_exprs)

@@ -159,21 +165,52 @@
            end
            """

-        sql = f"""
-        select * from
-        (select
-            count(*) as value_count,
-            '{gn}' as group_name,
-            cast({col_or_expr} as varchar) as value,
-            (select count({col_or_expr}) from {table_name}) as total_non_null_rows,
-            (select count(*) from {table_name}) as total_rows_inc_nulls,
-            (select count(distinct {col_or_expr}) from {table_name})
-                as distinct_value_count
-        from {table_name}
-        where {col_or_expr} is not null
-        group by {col_or_expr}
-        order by count(*) desc) column_stats
-        """

        if not cast_arrays_as_str and raw_expr in array_cols:
            sql = f"""
            select * from
            (select
                value,
                count(*) as value_count,
                '{gn}' as group_name,

                (select count(value) from
                    (select unnest({col_or_expr}) as value from {table_name}))
                    as total_non_null_rows,

                (select count(*) from
                    (select unnest({col_or_expr}) as value from {table_name}))
                    as total_rows_inc_nulls,

                (select count(distinct value) from
                    (select unnest({col_or_expr}) as value from {table_name}))
                    as distinct_value_count

            from
                (select cast(unnest({col_or_expr}) as varchar) as value
                from {table_name})
            group by value
            order by count(*) desc) column_stats
            """

        else:
            sql = f"""
            select * from
            (select
                cast({col_or_expr} as varchar) as value,
                count(*) as value_count,
                '{gn}' as group_name,
                (select count({col_or_expr}) from {table_name}) as total_non_null_rows,
                (select count(*) from {table_name}) as total_rows_inc_nulls,
                (select count(distinct {col_or_expr}) from {table_name})
                    as distinct_value_count
            from {table_name}
            where {col_or_expr} is not null
            group by {col_or_expr}
            order by count(*) desc) column_stats
            """

        sqls.append(sql)

    return " union all ".join(sqls)

Review comment (RobinL, Member, Jul 7, 2023), attached to the unnest subqueries above:
Is there a potential issue with unnest here that the function name changes depending on the dialect, e.g. in Spark it's explode? If so, might consider adding something to each backend similar to how we deal with random samples (where the dialect varies between linkers). Apologies if you're already aware.

Reply (Contributor, Author):
Thanks for the heads up... yes, this was on my radar and was more complicated to deal with, as the 'translator' didn't translate the desired function, so Tom is going to work on applying this to the other backends. But I will ask him about the solution used with random samples as it's good to know.
@@ -190,18 +227,23 @@ def _add_100_percentile_to_df_percentiles(percentile_rows):
    return percentile_rows


-def profile_columns(linker, column_expressions, top_n=10, bottom_n=10):
-    df_concat = linker._initialise_df_concat()
def profile_columns(
    linker, column_expressions, top_n=10, bottom_n=10, cast_arrays_as_str=False
):

    df_concat = linker._initialise_df_concat(materialise=True)

    input_dataframes = []
    if df_concat:
        input_dataframes.append(df_concat)

    array_cols = df_concat.get_array_cols()

    column_expressions_raw = ensure_is_list(column_expressions)
    column_expressions = expressions_to_sql(column_expressions_raw)

    sql = _col_or_expr_frequencies_raw_data_sql(
-        column_expressions_raw, "__splink__df_concat"
        column_expressions_raw, array_cols, df_concat.physical_name, cast_arrays_as_str
    )

    linker._enqueue_sql(sql, "__splink__df_all_column_value_frequencies")
4 changes: 4 additions & 0 deletions splink/splink_dataframe.py
@@ -62,6 +62,10 @@ def _drop_table_from_database(self, force_non_splink_table=False):
"_drop_table_from_database from database not " "implemented for this linker"
)

def get_array_cols(self):
logger.warning("Profiling arrays is not implemented for this linker")
return []

def drop_table_from_database_and_remove_from_cache(
self, force_non_splink_table=False
):
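As a sketch of the extension point this fallback creates: a backend opts in to array profiling by overriding get_array_cols on its SplinkDataFrame subclass, as the DuckDB implementation above does; every other backend inherits this warning and returns no array columns. The subclass, schema helper and is_array flag below are hypothetical, not part of the PR.

```py
class MyBackendDataFrame(SplinkDataFrame):
    def get_array_cols(self):
        # Hypothetical: assumes the backend exposes a schema dataframe
        # with column_name and an is_array flag.
        schema = self._get_schema_df()
        return [
            col
            for col, is_array in zip(schema.column_name, schema.is_array)
            if is_array
        ]
```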