diff --git a/splink/linker.py b/splink/linker.py index ae0116f48e..f173e86a4f 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -250,18 +250,38 @@ def __init__( self.debug_mode = False @property - def _get_input_columns( + def _input_columns( self, - as_list=True, ): """Retrieve the column names from the input dataset(s)""" - df_obj: SplinkDataFrame = next(iter(self._input_tables_dict.values())) - - column_names = ( - [col.name() for col in df_obj.columns] if as_list else df_obj.columns - ) + input_dfs = self._input_tables_dict.values() + + # get a list of the column names for each input frame + # sort it for consistent ordering, and give each frame's + # columns as a tuple so we can hash it + column_names_by_input_df = [ + tuple(sorted([col.name() for col in input_df.columns])) + for input_df in input_dfs + ] + # check that the set of input columns is the same for each frame, + # fail if the sets are different + if len(set(column_names_by_input_df)) > 1: + common_cols = set.intersection( + *(set(col_names) for col_names in column_names_by_input_df) + ) + problem_names = { + col + for frame_col_names in column_names_by_input_df + for col in frame_col_names + if col not in common_cols + } + raise SplinkException( + "All linker input frames must have the same set of columns. " + "The following columns were not found in all input frames: " + + ", ".join(problem_names) + ) - return column_names + return next(iter(input_dfs)).columns @property def _cache_uid(self): @@ -3044,8 +3064,9 @@ def missingness_chart(self, input_dataset: str = None): Args: input_dataset (str, optional): Name of one of the input tables in the - database. If provided, missingness will be computed for this table alone. - Defaults to None. + database. If provided, missingness will be computed for + this table alone. + Defaults to None. Examples: ```py diff --git a/splink/missingness.py b/splink/missingness.py index 197c307f8d..bd5711bd6c 100644 --- a/splink/missingness.py +++ b/splink/missingness.py @@ -40,13 +40,13 @@ def missingness_sqls(columns, input_tablename): def missingness_data(linker, input_tablename): + columns = linker._input_columns if input_tablename is None: splink_dataframe = linker._initialise_df_concat(materialise=True) else: splink_dataframe = linker._table_to_splink_dataframe( input_tablename, input_tablename ) - columns = splink_dataframe.columns sqls = missingness_sqls(columns, splink_dataframe.physical_name) diff --git a/splink/profile_data.py b/splink/profile_data.py index bcf341777d..f09d6340f6 100644 --- a/splink/profile_data.py +++ b/splink/profile_data.py @@ -232,7 +232,7 @@ def profile_columns(linker, column_expressions=None, top_n=10, bottom_n=10): """ if not column_expressions: - column_expressions = linker._get_input_columns + column_expressions = [col.name() for col in linker._input_columns] df_concat = linker._initialise_df_concat() diff --git a/tests/test_missingness.py b/tests/test_missingness.py new file mode 100644 index 0000000000..e4c4110705 --- /dev/null +++ b/tests/test_missingness.py @@ -0,0 +1,35 @@ +import pandas as pd +from pytest import raises + +from splink.exceptions import SplinkException +from tests.decorator import mark_with_dialects_excluding + + +@mark_with_dialects_excluding() +def test_missingness_chart(dialect, test_helpers): + helper = test_helpers[dialect] + + df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + + linker = helper.Linker( + df, {"link_type": "dedupe_only"}, **helper.extra_linker_args() + ) + linker.missingness_chart() + + +@mark_with_dialects_excluding() +def test_missingness_chart_mismatched_columns(dialect, test_helpers): + helper = test_helpers[dialect] + + df_l = helper.load_frame_from_csv( + "./tests/datasets/fake_1000_from_splink_demos.csv" + ) + df_r = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + df_r.rename(columns={"surname": "SURNAME"}, inplace=True) + df_r = helper.convert_frame(df_r) + + linker = helper.Linker( + [df_l, df_r], {"link_type": "link_only"}, **helper.extra_linker_args() + ) + with raises(SplinkException): + linker.missingness_chart()