From 50a37db8b081b968e28e3f557ad6a673188136d4 Mon Sep 17 00:00:00 2001 From: Oege Dijk Date: Mon, 11 Mar 2024 22:06:32 +0100 Subject: [PATCH] rewrite merge_categorical_columns to avoid fragmentation warning --- explainerdashboard/explainer_methods.py | 37 +++++++++---------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/explainerdashboard/explainer_methods.py b/explainerdashboard/explainer_methods.py index 16c9235..8464424 100644 --- a/explainerdashboard/explainer_methods.py +++ b/explainerdashboard/explainer_methods.py @@ -388,44 +388,33 @@ def retrieve_onehot_value( def merge_categorical_columns( X, onehot_dict=None, cols=None, not_encoded_dict=None, sep="_", drop_regular=False ): - """ - Returns a new feature Dataframe X_cats where the onehotencoded - categorical features have been merged back with the old value retrieved - from the encodings. + cat_pieces = [] - Args: - X (pd.DataFrame): original dataframe with onehotencoded columns, e.g. - columns=['Age', 'Sex_Male', 'Sex_Female"]. - onehot_dict (dict): dict of features with lists for onehot-encoded variables, - e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} - cols (list[str]): list of columns to return - sep (str): separator used in the encoding, e.g. "_" for Sex_Male. - Defaults to "_". - - Returns: - pd.DataFrame, with onehot encodings merged back into categorical columns. - """ - X_cats = pd.DataFrame() - not_encoded_dict = not_encoded_dict or {} for col_name, col_list in onehot_dict.items(): if len(col_list) > 1: - X_cats[col_name] = retrieve_onehot_value( + merged_col = retrieve_onehot_value( X, col_name, col_list, not_encoded_dict.get(col_name, "NOT_ENCODED"), sep, ).astype("category") + cat_pieces.append(pd.DataFrame({col_name: merged_col})) else: if not drop_regular: if is_categorical_dtype(X[col_name]): - X_cats[col_name] = pd.Categorical(X[col_name]) + cat_pieces.append( + pd.DataFrame({col_name: pd.Categorical(X[col_name])}) + ) else: - X_cats.loc[:, col_name] = X[col_name].values + cat_pieces.append(pd.DataFrame({col_name: X[col_name].values})) + + X_cats = pd.concat(cat_pieces, axis=1) + if cols: - return X_cats[cols] - else: - return X_cats + X_cats = X_cats[cols] + + return X_cats def matching_cols(cols1, cols2):