From 1c279839ff195ac328fe7eb8a5376c8577e80311 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Fri, 13 Sep 2024 16:26:34 +0200 Subject: [PATCH 1/3] Add feature importance local and cumulative --- setup.py | 2 +- shapash/backend/base_backend.py | 8 +- shapash/explainer/multi_decorator.py | 10 +- shapash/explainer/smart_explainer.py | 27 +- shapash/explainer/smart_plotter.py | 1213 ++++++++++++----- shapash/explainer/smart_state.py | 8 +- shapash/manipulation/filters.py | 6 +- shapash/manipulation/summarize.py | 8 +- shapash/style/colors.json | 24 +- shapash/style/style_utils.py | 3 + shapash/webapp/smart_app.py | 2 +- .../explainer/test_smart_plotter.py | 283 +++- .../unit_tests/webapp/utils/test_callbacks.py | 34 +- 13 files changed, 1223 insertions(+), 405 deletions(-) diff --git a/setup.py b/setup.py index 53f525a3..a12cfaed 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ version=version_d["__version__"], python_requires=">3.8, <3.13", url="https://github.com/MAIF/shapash", - author="Yann Golhen, Sebastien Bidault, Yann Lagre, Maxime Gendre", + author="Yann Golhen, Sebastien Bidault, Yann Lagre, Maxime Gendre, Thomas Bouché, Maxime Lecardonnel, Guillaume Vignal", author_email="yann.golhen@maif.fr", description="Shapash is a Python library which aims to make machine learning interpretable and understandable by everyone.", long_description=long_description, diff --git a/shapash/backend/base_backend.py b/shapash/backend/base_backend.py index 3e5206b3..821c4a4d 100644 --- a/shapash/backend/base_backend.py +++ b/shapash/backend/base_backend.py @@ -107,7 +107,11 @@ def get_local_contributions( return local_contributions def get_global_features_importance( - self, contributions: pd.DataFrame, explain_data: Optional[dict] = None, subset: Optional[List[int]] = None + self, + contributions: pd.DataFrame, + explain_data: Optional[dict] = None, + subset: Optional[List[int]] = None, + norm: int = 1, ) -> Union[pd.Series, List[pd.Series]]: """Get global 
contributions using the explainer data computed in the `run_explainer` method. @@ -132,7 +136,7 @@ def get_global_features_importance( contributions = [c.loc[subset] for c in contributions] else: contributions = contributions.loc[subset] - return state.compute_features_import(contributions) + return state.compute_features_import(contributions, norm) def format_and_aggregate_local_contributions( self, diff --git a/shapash/explainer/multi_decorator.py b/shapash/explainer/multi_decorator.py index f734d10f..ffeeca49 100644 --- a/shapash/explainer/multi_decorator.py +++ b/shapash/explainer/multi_decorator.py @@ -226,23 +226,23 @@ def summarize(self, s_contribs, var_dicts, xs_sorted, masks, columns_dict, featu arg_tup = list(zip(s_contribs, var_dicts, xs_sorted, masks)) return self.delegate("summarize", arg_tup, columns_dict, features_dict) - def compute_features_import(self, contributions): + def compute_features_import(self, contributions, norm=1): """ Compute a relative features importance, sum of absolute values - ​​of the contributions for each - features importance compute in base 100 + ​​of the contributions for each + features importance compute in base 100 Parameters ---------- contributions : list - list of pandas.DataFrames containing contributions + list of pandas.DataFrames containing contributions Returns ------- list list of features importance pandas.series """ - return self.delegate("compute_features_import", contributions) + return self.delegate("compute_features_import", contributions, norm) def compute_grouped_contributions(self, contributions, features_groups): """ diff --git a/shapash/explainer/smart_explainer.py b/shapash/explainer/smart_explainer.py index 8d3143e7..465ce1dc 100644 --- a/shapash/explainer/smart_explainer.py +++ b/shapash/explainer/smart_explainer.py @@ -1,6 +1,7 @@ """ Smart explainer module """ + import copy import logging import shutil @@ -217,13 +218,12 @@ def __init__( self.backend_kwargs = backend_kwargs self.features_dict = 
dict() if features_dict is None else copy.deepcopy(features_dict) self.label_dict = label_dict - self.plot = SmartPlotter(self) self.title_story = title_story if title_story is not None else "" self.palette_name = palette_name if palette_name else "default" self.colors_dict = copy.deepcopy(select_palette(colors_loading(), self.palette_name)) if colors_dict is not None: self.colors_dict.update(colors_dict) - self.plot.define_style_attributes(colors_dict=self.colors_dict) + self.plot = SmartPlotter(self, self.colors_dict) self._case, self._classes = check_model(self.model) self.postprocessing = postprocessing @@ -359,7 +359,7 @@ def _compile_features_groups(self, features_groups): Performs required computations for groups of features. """ if self.backend.support_groups is False: - raise AssertionError(f"Selected backend ({self.backend.name}) " f"does not support groups of features.") + raise AssertionError(f"Selected backend ({self.backend.name}) does not support groups of features.") # Compute contributions for groups of features self.contributions_groups = self.state.compute_grouped_contributions(self.contributions, features_groups) self.features_imp_groups = None @@ -931,7 +931,7 @@ def to_pandas( return pd.concat([y_pred, summary], axis=1) - def compute_features_import(self, force=False): + def compute_features_import(self, force=False, local=False): """ Compute a relative features importance, sum of absolute values of the contributions for each. 
@@ -949,11 +949,26 @@ def compute_features_import(self, force=False): index of the serie = contributions.columns """ self.features_imp = self.backend.get_global_features_importance( - contributions=self.contributions, explain_data=self.explain_data, subset=None + contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=1 ) + if local: + self.features_imp_local_lev1 = self.backend.get_global_features_importance( + contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=3 + ) + self.features_imp_local_lev2 = self.backend.get_global_features_importance( + contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=7 + ) + if self.features_groups is not None and self.features_imp_groups is None: - self.features_imp_groups = self.state.compute_features_import(self.contributions_groups) + self.features_imp_groups = self.state.compute_features_import(self.contributions_groups, norm=1) + if local: + self.features_imp_groups_local_lev1 = self.state.compute_features_import( + self.contributions_groups, norm=3 + ) + self.features_imp_groups_local_lev2 = self.state.compute_features_import( + self.contributions_groups, norm=7 + ) def compute_features_stability(self, selection): """ diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index 3497ecb0..bf38e538 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -12,6 +12,7 @@ import pandas as pd import plotly.express as px import scipy.cluster.hierarchy as sch +from matplotlib.colors import LinearSegmentedColormap from plotly import graph_objs as go from plotly.offline import plot from plotly.subplots import make_subplots @@ -19,7 +20,7 @@ from shapash.manipulation.select_lines import select_lines from shapash.manipulation.summarize import compute_corr, project_feature_values_1d -from shapash.style.style_utils import colors_loading, define_style, select_palette +from 
shapash.style.style_utils import colors_loading, define_style, get_pyplot_color, select_palette from shapash.utils.utils import ( add_line_break, add_text, @@ -46,13 +47,16 @@ class SmartPlotter: >>> xpl.plot.my_plot_method(param=value) """ - def __init__(self, explainer): - self.explainer = explainer - self._palette_name = list(colors_loading().keys())[0] - self._style_dict = define_style(select_palette(colors_loading(), self._palette_name)) - self.round_digit = None - self.last_stability_selection = False - self.last_compacity_selection = False + def __init__(self, explainer, colors_dict=None): + self._explainer = explainer + if colors_dict: + self._style_dict = define_style(colors_dict) + else: + palette_name = list(colors_loading().keys())[0] + self._style_dict = define_style(select_palette(colors_loading(), palette_name)) + self._round_digit = None + self._last_stability_selection = False + self._last_compacity_selection = False def define_style_attributes(self, colors_dict): """ @@ -64,7 +68,7 @@ def define_style_attributes(self, colors_dict): """ self._style_dict = define_style(colors_dict) - def tuning_colorscale(self, values, keep_90_pct=False): + def _tuning_colorscale(self, values, keep_90_pct=False): """ Adjusts the color scale based on the distribution of points. 
@@ -127,15 +131,15 @@ def tuning_colorscale(self, values, keep_90_pct=False): return color_scale, cmin, cmax - def tuning_round_digit(self): + def _tuning_round_digit(self): """ adapts the display of the number of digit to the distribution of points """ quantile = [0.25, 0.75] - desc_df = self.explainer.y_pred.describe(percentiles=quantile) + desc_df = self._explainer.y_pred.describe(percentiles=quantile) perc1, perc2 = list(desc_df.loc[[str(int(p * 100)) + "%" for p in quantile]].values) p_diff = perc2 - perc1 - self.round_digit = compute_digit_number(p_diff) + self._round_digit = compute_digit_number(p_diff) def _update_contributions_fig( self, @@ -192,7 +196,6 @@ def _update_contributions_fig( title = f"{truncate_str(feature_name)} - Feature Contribution" # Add subtitle and / or addnote if subtitle or addnote: - # title += f"
{add_text([subtitle, addnote], sep=' - ')}
" if subtitle and addnote: title += "
" + subtitle + " - " + addnote + "" elif subtitle: @@ -206,10 +209,10 @@ def _update_contributions_fig( dict_xaxis["text"] = truncate_str(feature_name, 110) dict_yaxis["text"] = "Contribution" - if self.explainer._case == "regression": + if self._explainer._case == "regression": colorpoints = pred colorbar_title = "Predicted" - elif self.explainer._case == "classification": + elif self._explainer._case == "classification": colorpoints = proba_values colorbar_title = "Predicted Proba" @@ -224,7 +227,7 @@ def _update_contributions_fig( fig.layout.coloraxis.cmax = cmax elif fig.data[0].type != "violin": - if self.explainer._case == "classification" and pred is not None: + if self._explainer._case == "classification" and pred is not None: fig.data[-1].marker.color = pred.iloc[:, 0].apply( lambda x: ( self._style_dict["violin_area_classif"][1] @@ -256,7 +259,7 @@ def _update_contributions_fig( if file_name: plot(fig, filename=file_name, auto_open=auto_open) - def plot_scatter( + def _plot_scatter( self, feature_values, contributions, @@ -325,7 +328,7 @@ def plot_scatter( proba_values = proba_values.loc[feature_values.index] # add break line to X label if necessary - max_len_by_row = max([round(50 / self.explainer.features_desc[feature_values.columns.values[0]]), 8]) + max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values.columns.values[0]]), 8]) feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply( add_line_break, args=( @@ -551,7 +554,7 @@ def _create_jittered_points( return jittered_points - def prepare_hover_text(self, feature_values, pred, feature_name): + def _prepare_hover_text(self, feature_values, pred, feature_name): """ Prepares the hover text for a Plotly plot based on feature values and predictions. 
@@ -668,7 +671,7 @@ def _add_violin_and_scatter( self._add_scatter_trace(fig, x, y, c, marker, hovertext, hovertemplate, customdata, secondary_y) - def plot_violin( + def _plot_violin( self, feature_values, contributions, @@ -731,7 +734,7 @@ def plot_violin( column_name = feature_values.columns[0] feature_values = feature_values.sort_values(by=column_name) - max_len_by_row = max([round(50 / self.explainer.features_desc[feature_values.columns.values[0]]), 8]) + max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values.columns.values[0]]), 8]) feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply( add_line_break, args=( @@ -745,7 +748,7 @@ def plot_violin( if proba_values is not None: proba_values = proba_values.loc[feature_values.index] - hv_text_df, hovertemplate = self.prepare_hover_text(feature_values, pred, feature_name) + hv_text_df, hovertemplate = self._prepare_hover_text(feature_values, pred, feature_name) feature_values_counts = feature_values.value_counts() xs = feature_values_counts.index.get_level_values(0).sort_values() @@ -753,9 +756,9 @@ def plot_violin( y_upper = (feature_values_counts.loc[xs] / feature_values_counts.sum()).values.flatten() y_upper_max = y_upper.max() - if self.explainer._case == "classification": + if self._explainer._case == "classification": colorpoints = proba_values - elif self.explainer._case == "regression": + elif self._explainer._case == "regression": colorpoints = pred else: colorpoints = None @@ -778,7 +781,7 @@ def plot_violin( ) ) - if pred is not None and self.explainer._case == "classification": + if pred is not None and self._explainer._case == "classification": # Negative case feature_cond_neg = (pred.iloc[:, 0] != col_modality) & (feature_values.iloc[:, 0] == c) self._add_violin_and_scatter( @@ -892,7 +895,7 @@ def plot_violin( return fig - def plot_features_import( + def _plot_features_import( self, feature_imp1, feature_imp2=None, @@ -955,9 +958,9 @@ def plot_features_import( ( 
self._style_dict["featureimp_groups"][0] if ( - self.explainer.features_groups is not None - and self.explainer.inv_features_dict.get(f.replace("", "").replace("", "")) - in self.explainer.features_groups.keys() + self._explainer.features_groups is not None + and self._explainer.inv_features_dict.get(f.replace("", "").replace("", "")) + in self._explainer.features_groups.keys() ) else dict_style_bar1["color"] ) @@ -977,7 +980,7 @@ def plot_features_import( margin={"l": 160, "r": 0, "t": topmargin, "b": 50}, ) # To change ticktext when the x label size is upper than 30 and zoom is False - if (type(feature_imp1.index[0]) == str) & (not zoom): + if (isinstance(feature_imp1.index[0], str)) & (not zoom): # change index to abc...abc if its length is upper than 30 index_val = [y.replace(y[24 : len(y) - 3], "...") if len(y) > 30 else y for y in feature_imp1.index] else: @@ -1014,7 +1017,325 @@ def plot_features_import( plot(fig, filename=file_name, auto_open=auto_open) return fig - def plot_bar_chart( + def _plot_local_features_import( + self, + feat_imp, + title="Features Importance Global-Local", + addnote=None, + subtitle=None, + width=900, + height=500, + file_name=None, + auto_open=False, + zoom=False, + ): + """ + Plot features importance computed with the prediction set. + Parameters + ---------- + feat_imp : dict of pd.Series + Feature importance computed with every rows :global, semi-local and local + title : str + Title of the plot, default set to 'Features Importance' + addnote : String (default: None) + Specify a note to display + subtitle : String (default: None) + Subtitle to display + width : Int (default: 900) + Plotly figure - layout width + height : Int (default: 500) + Plotly figure - layout height + file_name: string (optional) + Specify the save path of html files. If it is not provided, no file will be saved. 
+ auto_open: bool (default=False) + open automatically the plot + zoom: bool (default=False) + graph is currently zoomed + """ + dict_t = copy.deepcopy(self._style_dict["dict_title"]) + topmargin = 80 + # Add subtitle and / or addnote + if subtitle or addnote: + if subtitle and addnote: + title += "
" + subtitle + " - " + addnote + "" + elif subtitle: + title += "
" + subtitle + "" + else: + title += "
" + addnote + "" + topmargin = topmargin + 15 + dict_t.update(text=title) + dict_xaxis = copy.deepcopy(self._style_dict["dict_xaxis"]) + dict_xaxis.update(text="Mean absolute Contribution") + dict_yaxis = copy.deepcopy(self._style_dict["dict_yaxis"]) + dict_yaxis.update(text=None) + dict_style_bar = {} + for type_feat, i in zip(["global", "semi-local", "local"], [1, 3, 4]): + dict_style_bar[type_feat] = self._style_dict["dict_featimp_colors"][i] + dict_yaxis["text"] = None + + # Change bar color for groups of features + marker_color = [ + ( + self._style_dict["featureimp_groups"][0] + if ( + self._explainer.features_groups is not None + and self._explainer.inv_features_dict.get(f.replace("", "").replace("", "")) + in self._explainer.features_groups.keys() + ) + else dict_style_bar["global"]["color"] + ) + for f in feat_imp["global"].index + ] + + layout = go.Layout( + barmode="group", + template="none", + autosize=False, + width=width, + height=height, + title=dict_t, + xaxis_title=dict_xaxis, + yaxis_title=dict_yaxis, + hovermode="closest", + margin={"l": 160, "r": 0, "t": topmargin, "b": 50}, + ) + + data = [] + for type_feat in ["local", "semi-local", "global"]: + feature_imp = feat_imp[type_feat] + style_bar = dict_style_bar[type_feat] + + data.append( + go.Bar( + x=feature_imp.round(4), + y=feature_imp.index, + orientation="h", + name=type_feat.capitalize(), + marker=style_bar, + marker_color=marker_color if type_feat == "global" else style_bar["color"], + hovertemplate="Feature: %{customdata}
Contribution: %{x:.4f}", + customdata=feature_imp.index, + ) + ) + + fig = go.Figure(data=data, layout=layout) + + # Update ticktext + # To change ticktext when the x label size is upper than 30 and zoom is False + if (isinstance(feat_imp["global"].index[0], str)) & (not zoom): + # change index to abc...abc if its length is upper than 30 + index_val = [y.replace(y[24 : len(y) - 3], "...") if len(y) > 30 else y for y in feat_imp["global"].index] + else: + index_val = feat_imp["global"].index + fig.update_yaxes(ticktext=index_val, tickvals=feat_imp["global"].index, tickmode="array", dtick=1) + fig.update_yaxes(automargin=True) + if file_name: + plot(fig, filename=file_name, auto_open=auto_open) + return fig + + def _plot_feature_contributions_cumulative( + self, + feature_imp1, + contributions_case, + title="Feature Contributions Cumulative Plot", + addnote=None, + subtitle=None, + width=900, + height=500, + normalize_by_nb_samples=False, + degree="slider", + file_name=None, + auto_open=False, + zoom=False, + ): + """ + Generates a cumulative plot of feature contributions with a slider to adjust the degree. + + Parameters: + - feature_imp1: DataFrame of feature importances. + - title (str): The title of the plot. + - width (int): The width of the plot in pixels. + - height (int): The height of the plot in pixels. + - normalize_by_nb_samples (bool): Whether to normalize the contributions by the number of samples. + - degree (str or float): The degree of normalization to apply. Use 'slider' for interactive degree control. + - file_name (str, optional): Specify the save path of HTML files. If not provided, no file will be saved. + - auto_open (bool): Whether to automatically open the plot. + + Returns: + - fig (plotly.graph_objs._figure.Figure): The generated cumulative plot figure. 
+ """ + # Number of features + num_features = len(feature_imp1) + + # Generate color scale + col_scale = get_pyplot_color(colors=self._style_dict["feature_contributions_cumulative"]) + cmap = LinearSegmentedColormap.from_list("feature_contributions_cumulative", col_scale, N=256) + colors = [cmap(i / num_features) for i in range(num_features)] + colors_hex = ["#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255)) for r, g, b, _ in colors] + + # Initialize data for storing the series + data = [] + serie_tot = None + + lst_feat = list(feature_imp1.index)[::-1] + lst_feat = [f.replace("", "").replace("", "") for f in lst_feat] + + # Process each feature's contributions and compute cumulative sums + for name in lst_feat: + serie = ( + contributions_case[self._explainer.inv_features_dict.get(name)] + .abs() + .sort_values(ascending=False) + .cumsum() + .reset_index(drop=True) + ) + data.append(serie) + + # Accumulate the total series for normalization + if serie_tot is None: + serie_tot = serie.copy() + else: + serie_tot += serie + + # Create the Plotly traces for each series + dict_t = copy.deepcopy(self._style_dict["dict_title"]) + topmargin = 80 + # Add subtitle and / or addnote + if subtitle or addnote: + if subtitle and addnote: + title += "
" + subtitle + " - " + addnote + "" + elif subtitle: + title += "
" + subtitle + "" + else: + title += "
" + addnote + "" + topmargin = topmargin + 15 + dict_t.update(text=title) + + if (isinstance(lst_feat[0], str)) & (not zoom): + # change index to abc...abc if its length is upper than 30 + index_val = [y.replace(y[24 : len(y) - 3], "...") if len(y) > 30 else y for y in lst_feat] + else: + index_val = lst_feat + + figs = [] + for i, serie in enumerate(data): + serie_values = serie.copy() + + # Optionally normalize by the number of samples + if normalize_by_nb_samples: + serie_values /= pd.Series(range(1, len(serie_values) + 1)) + + # Apply initial degree-based normalization + if degree not in [0, "slider"]: + serie_values /= serie_tot ** degree + + # Append the trace for the current series + figs.append( + go.Scatter( + x=serie.index, + y=serie_values, + mode="lines", + name=index_val[i], + hoverinfo="text", # Use 'text' to refer to custom hovertext + text=lst_feat[i], # Set custom text for hover + line=dict(color=colors_hex[i], width=3), + hoverlabel=dict( + font_size=12, # Optional: adjust font size for better readability + ), + ) + ) + + # Define layout with a clean white background and title + layout = go.Layout( + title=dict_t, + xaxis=dict(visible=False), + yaxis=dict(visible=False, autorange=True), + plot_bgcolor="white", + paper_bgcolor="white", + width=width, + height=height, + margin={"l": 10, "r": 0, "t": topmargin, "b": 10}, + ) + + # Create the initial figure with the data and layout + fig = go.Figure(data=figs, layout=layout) + + # Create a list of frames with updated data for each degree value + if degree == "slider": + frames = [] + degree_range = np.round(np.arange(0, 1.1, 0.1), 1) + + for deg in degree_range: + new_figs = [] + max_y = 0 # Track max value for y-axis rescaling + for i, serie in enumerate(data): + serie_values = serie.copy() + + if normalize_by_nb_samples: + serie_values /= pd.Series(range(1, len(serie_values) + 1)) + + # Apply degree-based normalization + if deg != 0: + serie_values /= serie_tot ** (-deg) + + max_y = max(max_y, 
serie_values.max()) + + new_figs.append( + go.Scatter( + x=serie.index, + y=serie_values, + mode="lines", + hoverinfo="text", # Use 'text' to refer to custom hovertext + text=lst_feat[i], # Set custom text for hover + line=dict(color=colors_hex[i], width=3), + hoverlabel=dict( + font_size=12, # Optional: adjust font size for better readability + ), + ) + ) + + # Layout for this degree value, adjusting y-axis range + frame_layout = go.Layout( + yaxis=dict(visible=False, autorange=True), + plot_bgcolor="white", + paper_bgcolor="white", + width=width, + height=height, + ) + + # Append each frame with its own layout + frames.append(go.Frame(data=new_figs, name=f"degree_{deg}", layout=frame_layout)) + + # Add slider to control the degree parameter + sliders = [ + { + "currentvalue": {"prefix": "Degree: "}, + "pad": {"b": 10}, + "steps": [ + { + "args": [ + [f"degree_{deg}"], + {"frame": {"duration": 300, "redraw": True}, "mode": "immediate"}, + ], + "label": str(deg), + "method": "animate", + } + for deg in degree_range + ], + } + ] + + # Add frames and sliders to the figure + fig.update(frames=frames) + fig.update_layout(sliders=sliders) + + # Optionally save the plot to a file + if file_name: + plot(fig, filename=file_name, auto_open=auto_open) + + return fig + + def _plot_bar_chart( self, index_value, var_dict, @@ -1097,18 +1418,18 @@ def plot_bar_chart( # If bar is a group of features, hovertext includes the values of the features of the group # And color changes if ( - self.explainer.features_groups is not None - and self.explainer.inv_features_dict.get(expl[0]) in self.explainer.features_groups.keys() + self._explainer.features_groups is not None + and self._explainer.inv_features_dict.get(expl[0]) in self._explainer.features_groups.keys() and len(index_value) > 0 ): - group_name = self.explainer.inv_features_dict.get(expl[0]) - feat_groups_values = self.explainer.x_init[self.explainer.features_groups[group_name]].loc[ + group_name = 
self._explainer.inv_features_dict.get(expl[0]) + feat_groups_values = self._explainer.x_init[self._explainer.features_groups[group_name]].loc[ index_value[0] ] hoverlabel = "
".join( [ "{} :{}".format( - add_line_break(self.explainer.features_dict.get(f_name, f_name), 40, maxlen=120), + add_line_break(self._explainer.features_dict.get(f_name, f_name), 40, maxlen=120), add_line_break(f_value, 40, maxlen=160), ) for f_name, f_value in feat_groups_values.to_dict().items() @@ -1129,12 +1450,12 @@ def plot_bar_chart( else: trunc_new_value = trunc_value if len(contrib) <= yaxis_max_label and ( - self.explainer.features_groups is None + self._explainer.features_groups is None # We don't want to display label values for t-sne projected values of groups of features. or ( - self.explainer.features_groups is not None - and self.explainer.inv_features_dict.get(expl[0]) - not in self.explainer.features_groups.keys() + self._explainer.features_groups is not None + and self._explainer.inv_features_dict.get(expl[0]) + not in self._explainer.features_groups.keys() ) ): # ylabel is based on trunc_new_value @@ -1175,7 +1496,6 @@ def plot_bar_chart( fig = go.Figure(data=[x[-1] for x in bars], layout=layout) fig.update_yaxes(dtick=1) fig.update_yaxes(automargin=True) - # fig.update_xaxes(automargin=True) if file_name: plot(fig, filename=file_name, auto_open=auto_open) @@ -1196,7 +1516,7 @@ def plot_bar_chart( ) return fig - def get_selection(self, line, var_dict, x_val, contrib): + def _get_selection(self, line, var_dict, x_val, contrib): """ An auxiliary function to select the row of interest. Parameters @@ -1222,7 +1542,7 @@ def get_selection(self, line, var_dict, x_val, contrib): return var_dict, x_val, contrib - def apply_mask_one_line(self, line, var_dict, x_val, contrib, label=None): + def _apply_mask_one_line(self, line, var_dict, x_val, contrib, label=None): """ An auxiliary function to select the mask to apply before plotting local explanation. @@ -1245,11 +1565,11 @@ def apply_mask_one_line(self, line, var_dict, x_val, contrib, label=None): Masked input lists. 
""" mask = np.array([True] * len(contrib)) - if hasattr(self.explainer, "mask"): - if isinstance(self.explainer.mask, list): - mask = self.explainer.mask[label].loc[line[0], :].values + if hasattr(self._explainer, "mask"): + if isinstance(self._explainer.mask, list): + mask = self._explainer.mask[label].loc[line[0], :].values else: - mask = self.explainer.mask.loc[line[0], :].values + mask = self._explainer.mask.loc[line[0], :].values contrib = contrib[mask] x_val = x_val[mask] @@ -1257,7 +1577,7 @@ def apply_mask_one_line(self, line, var_dict, x_val, contrib, label=None): return var_dict.tolist(), x_val.tolist(), contrib.tolist() - def check_masked_contributions(self, line, var_dict, x_val, contrib, label=None): + def _check_masked_contributions(self, line, var_dict, x_val, contrib, label=None): """ Check for masked contributions and update features_values and contrib to take the sum of masked contributions into account. @@ -1277,17 +1597,17 @@ def check_masked_contributions(self, line, var_dict, x_val, contrib, label=None) numpy arrays Input arrays updated with masked contributions. 
""" - if hasattr(self.explainer, "masked_contributions"): - if isinstance(self.explainer.masked_contributions, list): - ext_contrib = self.explainer.masked_contributions[label].loc[line[0], :].values + if hasattr(self._explainer, "masked_contributions"): + if isinstance(self._explainer.masked_contributions, list): + ext_contrib = self._explainer.masked_contributions[label].loc[line[0], :].values else: - ext_contrib = self.explainer.masked_contributions.loc[line[0], :].values + ext_contrib = self._explainer.masked_contributions.loc[line[0], :].values ext_var_dict = ["Hidden Negative Contributions", "Hidden Positive Contributions"] ext_x = ["", ""] ext_contrib = ext_contrib.tolist() - exclusion = np.where(np.array(ext_contrib) == 0)[0].tolist() + exclusion = np.flatnonzero(np.array(ext_contrib) == 0).tolist() exclusion.sort(reverse=True) for ind in exclusion: del ext_var_dict[ind] @@ -1300,7 +1620,7 @@ def check_masked_contributions(self, line, var_dict, x_val, contrib, label=None) return var_dict, x_val, contrib - def local_pred(self, index, label=None): + def _local_pred(self, index, label=None): """ compute a local pred to display in local_plot Parameters @@ -1312,16 +1632,16 @@ def local_pred(self, index, label=None): ------- float: Predict or predict_proba value """ - if self.explainer._case == "classification": - if self.explainer.proba_values is not None: - value = self.explainer.proba_values.iloc[:, [label]].loc[index].values[0] + if self._explainer._case == "classification": + if self._explainer.proba_values is not None: + value = self._explainer.proba_values.iloc[:, [label]].loc[index].values[0] else: value = None - elif self.explainer._case == "regression": - if self.explainer.y_pred is not None: - value = self.explainer.y_pred.loc[index] + elif self._explainer._case == "regression": + if self._explainer.y_pred is not None: + value = self._explainer.y_pred.loc[index] else: - value = self.explainer.model.predict(self.explainer.x_encoded.loc[[index]])[0] + 
value = self._explainer.model.predict(self._explainer.x_encoded.loc[[index]])[0] if isinstance(value, pd.Series): value = value.values[0] @@ -1394,21 +1714,23 @@ def local_plot( -------- >>> xpl.plot.local_plot(row_num=0) """ - display_groups = True if (display_groups is not False and self.explainer.features_groups is not None) else False + display_groups = ( + True if (display_groups is not False and self._explainer.features_groups is not None) else False + ) if display_groups: - data = self.explainer.data_groups + data = self._explainer.data_groups else: - data = self.explainer.data + data = self._explainer.data if index is not None: - if index in self.explainer.x_init.index: + if index in self._explainer.x_init.index: line = [index] else: line = [] elif row_num is not None: - line = [self.explainer.x_init.index[row_num]] + line = [self._explainer.x_init.index[row_num]] elif query is not None: - line = select_lines(self.explainer.x_init, query) + line = select_lines(self._explainer.x_init, query) else: line = [] @@ -1424,76 +1746,76 @@ def local_plot( else: # apply filter if the method have not yet been asked in order to limit the number of feature to display if ( - not hasattr(self.explainer, "mask_params") # If the filter method has not been called yet + not hasattr(self._explainer, "mask_params") # If the filter method has not been called yet # Or if the already computed mask was not updated with current display_groups parameter or ( isinstance(data["contrib_sorted"], pd.DataFrame) - and len(data["contrib_sorted"].columns) != len(self.explainer.mask.columns) + and len(data["contrib_sorted"].columns) != len(self._explainer.mask.columns) ) or ( isinstance(data["contrib_sorted"], list) - and len(data["contrib_sorted"][0].columns) != len(self.explainer.mask[0].columns) + and len(data["contrib_sorted"][0].columns) != len(self._explainer.mask[0].columns) ) ): - self.explainer.filter(max_contrib=20, display_groups=display_groups) + 
self._explainer.filter(max_contrib=20, display_groups=display_groups) - if self.explainer._case == "classification": + if self._explainer._case == "classification": if label is None: label = -1 - label_num, _, label_value = self.explainer.check_label_name(label) + label_num, _, label_value = self._explainer.check_label_name(label) contrib = data["contrib_sorted"][label_num] x_val = data["x_sorted"][label_num] var_dict = data["var_dict"][label_num] if show_predict is True: - pred = self.local_pred(line[0], label_num) + pred = self._local_pred(line[0], label_num) if pred is None: subtitle = f"Response: {label_value} - No proba available" else: subtitle = f"Response: {label_value} - Proba: {pred:.4f}" - elif self.explainer._case == "regression": + elif self._explainer._case == "regression": contrib = data["contrib_sorted"] x_val = data["x_sorted"] var_dict = data["var_dict"] label_num = None if show_predict is True: - pred_value = self.local_pred(line[0]) - if self.explainer.y_pred is not None: - if self.round_digit is None: - self.tuning_round_digit() - digit = self.round_digit + pred_value = self._local_pred(line[0]) + if self._explainer.y_pred is not None: + if self._round_digit is None: + self._tuning_round_digit() + digit = self._round_digit else: digit = compute_digit_number(pred_value) subtitle = f"Predict: {round(pred_value, digit)}" - var_dict, x_val, contrib = self.get_selection(line, var_dict, x_val, contrib) - var_dict, x_val, contrib = self.apply_mask_one_line(line, var_dict, x_val, contrib, label=label_num) + var_dict, x_val, contrib = self._get_selection(line, var_dict, x_val, contrib) + var_dict, x_val, contrib = self._apply_mask_one_line(line, var_dict, x_val, contrib, label=label_num) # use label of each column if display_groups: - var_dict = [self.explainer.features_dict[self.explainer.x_init_groups.columns[x]] for x in var_dict] + var_dict = [self._explainer.features_dict[self._explainer.x_init_groups.columns[x]] for x in var_dict] else: - var_dict 
= [self.explainer.features_dict[self.explainer.columns_dict[x]] for x in var_dict] + var_dict = [self._explainer.features_dict[self._explainer.columns_dict[x]] for x in var_dict] if show_masked: - var_dict, x_val, contrib = self.check_masked_contributions( + var_dict, x_val, contrib = self._check_masked_contributions( line, var_dict, x_val, contrib, label=label_num ) # Filtering all negative or positive contrib if specify in mask exclusion = [] - if hasattr(self.explainer, "mask_params"): - if self.explainer.mask_params["positive"] is True: - exclusion = np.where(np.array(contrib) < 0)[0].tolist() - elif self.explainer.mask_params["positive"] is False: - exclusion = np.where(np.array(contrib) > 0)[0].tolist() + if hasattr(self._explainer, "mask_params"): + positive = self._explainer.mask_params.get("positive") + if positive is not None: + exclusion = np.flatnonzero(np.array(contrib) < 0 if positive else np.array(contrib) > 0).tolist() + exclusion.sort(reverse=True) for expl in exclusion: del var_dict[expl] del x_val[expl] del contrib[expl] - fig = self.plot_bar_chart( + fig = self._plot_bar_chart( line, var_dict, x_val, contrib, yaxis_max_label, subtitle, width, height, file_name, auto_open, zoom ) return fig @@ -1559,33 +1881,33 @@ def contribution_plot( >>> xpl.plot.contribution_plot(0) """ - if self.explainer._case == "classification": - label_num, _, label_value = self.explainer.check_label_name(label) + if self._explainer._case == "classification": + label_num, _, label_value = self._explainer.check_label_name(label) if not isinstance(col, (str, int)): raise ValueError("parameter col must be string or int.") - if hasattr(self.explainer, "inv_features_dict"): - col = self.explainer.inv_features_dict.get(col, col) - col_is_group = self.explainer.features_groups and col in self.explainer.features_groups.keys() + if hasattr(self._explainer, "inv_features_dict"): + col = self._explainer.inv_features_dict.get(col, col) + col_is_group = 
self._explainer.features_groups and col in self._explainer.features_groups.keys() # Case where col is a group of features if col_is_group: - contributions = self.explainer.contributions_groups - col_label = self.explainer.features_dict[col] - col_name = self.explainer.features_groups[col] # Here col_name is actually a list of features - col_value_count = self.explainer.features_desc[col] + contributions = self._explainer.contributions_groups + col_label = self._explainer.features_dict[col] + col_name = self._explainer.features_groups[col] # Here col_name is actually a list of features + col_value_count = self._explainer.features_desc[col] else: - contributions = self.explainer.contributions - col_id = self.explainer.check_features_name([col])[0] - col_name = self.explainer.columns_dict[col_id] - col_value_count = self.explainer.features_desc[col_name] + contributions = self._explainer.contributions + col_id = self._explainer.check_features_name([col])[0] + col_name = self._explainer.columns_dict[col_id] + col_value_count = self._explainer.features_desc[col_name] - if self.explainer.features_dict: - col_label = self.explainer.features_dict[col_name] + if self._explainer.features_dict: + col_label = self._explainer.features_dict[col_name] else: col_label = col_name - list_ind, addnote = self.explainer.plot._subset_sampling( + list_ind, addnote = self._explainer.plot._subset_sampling( selection, max_points, None if col_is_group else col, col_value_count ) @@ -1597,57 +1919,58 @@ def contribution_plot( cmax = None # Classification Case - if self.explainer._case == "classification": + if self._explainer._case == "classification": subcontrib = contributions[label_num] - if self.explainer.y_pred is not None: - col_value = self.explainer._classes[label_num] + if self._explainer.y_pred is not None: + col_value = self._explainer._classes[label_num] subtitle = f"Response: {label_value}" # predict proba Color scale - if proba and self.explainer.proba_values is not None: - 
proba_values = self.explainer.proba_values.iloc[:, [label_num]] + if proba and self._explainer.proba_values is not None: + proba_values = self._explainer.proba_values.iloc[:, [label_num]] # Proba subset: proba_values = proba_values.loc[list_ind, :] - col_scale, cmin, cmax = self.tuning_colorscale(proba_values, keep_90_pct=True) - elif self.explainer.y_pred is not None: - pred_values = self.explainer.y_pred.iloc[:, [label_num]] + col_scale, cmin, cmax = self._tuning_colorscale(proba_values, keep_90_pct=True) + elif self._explainer.y_pred is not None: + pred_values = self._explainer.y_pred.iloc[:, [label_num]] # Prediction subset: pred_values = pred_values.loc[list_ind, :] - col_scale, cmin, cmax = self.tuning_colorscale(pred_values, keep_90_pct=True) + col_scale, cmin, cmax = self._tuning_colorscale(pred_values, keep_90_pct=True) # Regression Case - color scale - elif self.explainer._case == "regression": + elif self._explainer._case == "regression": subcontrib = contributions - if self.explainer.y_pred is not None: - col_scale, cmin, cmax = self.tuning_colorscale(self.explainer.y_pred.loc[list_ind], keep_90_pct=True) + if self._explainer.y_pred is not None: + col_scale, cmin, cmax = self._tuning_colorscale(self._explainer.y_pred.loc[list_ind], keep_90_pct=True) # Subset - if self.explainer.postprocessing_modifications: - feature_values = self.explainer.x_contrib_plot.loc[list_ind, col_name] + if self._explainer.postprocessing_modifications: + feature_values = self._explainer.x_contrib_plot.loc[list_ind, col_name] else: - feature_values = self.explainer.x_init.loc[list_ind, col_name] + feature_values = self._explainer.x_init.loc[list_ind, col_name] if col_is_group: feature_values = project_feature_values_1d( feature_values, col, - self.explainer.x_init, - self.explainer.x_encoded, - self.explainer.preprocessing, - features_dict=self.explainer.features_dict, + self._explainer.x_init, + self._explainer.x_encoded, + self._explainer.preprocessing, + 
features_dict=self._explainer.features_dict, ) contrib = subcontrib.loc[list_ind, col].to_frame() - if self.explainer.features_imp is None: - self.explainer.compute_features_import() + if self._explainer.features_imp is None: + self._explainer.compute_features_import() features_imp = ( - self.explainer.features_imp - if isinstance(self.explainer.features_imp, pd.Series) - else self.explainer.features_imp[0] + self._explainer.features_imp + if isinstance(self._explainer.features_imp, pd.Series) + else self._explainer.features_imp[0] ) top_features_of_group = ( - features_imp.loc[self.explainer.features_groups[col]].sort_values(ascending=False)[:4].index + features_imp.loc[self._explainer.features_groups[col]].sort_values(ascending=False)[:4].index ) # Displaying top 4 features metadata = { - self.explainer.features_dict[f_name]: self.explainer.x_init[f_name] for f_name in top_features_of_group + self._explainer.features_dict[f_name]: self._explainer.x_init[f_name] + for f_name in top_features_of_group } text_group = "Features values were projected on the x axis using t-SNE" # if group don't show addnote, if not, it's too long @@ -1660,23 +1983,23 @@ def contribution_plot( metadata = None feature_values = feature_values.to_frame() - if self.explainer.y_pred is not None: - y_pred = self.explainer.y_pred.loc[list_ind] + if self._explainer.y_pred is not None: + y_pred = self._explainer.y_pred.loc[list_ind] # Add labels if exist - if self.explainer._case == "classification" and self.explainer.label_dict is not None: - y_pred = y_pred.map(lambda x: self.explainer.label_dict[x]) - col_value = self.explainer.label_dict[col_value] + if self._explainer._case == "classification" and self._explainer.label_dict is not None: + y_pred = y_pred.map(lambda x: self._explainer.label_dict[x]) + col_value = self._explainer.label_dict[col_value] # round predict - elif self.explainer._case == "regression": - if self.round_digit is None: - self.tuning_round_digit() - y_pred = 
y_pred.map(lambda x: round(x, self.round_digit)) + elif self._explainer._case == "regression": + if self._round_digit is None: + self._tuning_round_digit() + y_pred = y_pred.map(lambda x: round(x, self._round_digit)) else: y_pred = None # selecting the best plot : Scatter, Violin? if col_value_count > violin_maxf: - fig = self.plot_scatter( + fig = self._plot_scatter( feature_values, contrib, col_label, @@ -1696,7 +2019,7 @@ def contribution_plot( zoom, ) else: - fig = self.plot_violin( + fig = self._plot_violin( feature_values, contrib, col_label, @@ -1719,6 +2042,7 @@ def contribution_plot( def features_importance( self, + mode="global", max_features=20, page="top", selection=None, @@ -1731,54 +2055,70 @@ def features_importance( file_name=None, auto_open=False, zoom=False, + normalize_by_nb_samples=False, + degree="slider", ): """ - features_importance display a plotly features importance plot. - in Multiclass Case, this features_importance focus on a label value. - User specifies the label value using label parameter. - the selection parameter allows the user to compare a subset to the global features - importance - features_importance tutorial offers several examples - (please check tutorial part of this doc) + Display a Plotly feature importance plot. + + This method generates a feature importance plot for both classification and regression models. + For multiclass classification, the plot will focus on the specified `label`. + Parameters ---------- - max_features: int (optional, default 20) - this argument limit the number of hbar in features importance plot - if max_features is 20, plot selects the 20 most important features - page: int, str (optional, default "top") - enable to select the features to plot between "top", "worse" or the number of the page - selection: list (optional, default None) - This argument allows to represent the importance calculated with a subset. 
- Subset features importance is compared to global in the plot - Argument must contains list of index, subset of the input DataFrame that we want to plot - label: integer or string (default -1) - If the label is of string type, check if it can be changed to integer to select the - good dataframe object. - group_name : str (optional, default None) - Allows to display the features importance of the variables that are grouped together - inside a group of features. - This parameter is only available if the SmartExplainer object has been compiled using - the features_groups optional parameter and should correspond to a key of - features_groups dictionary. - display_groups : bool (default True) - If groups of features are declared in SmartExplainer object, this parameter allows to - specify whether or not to display them. - force: bool (optional, default False) - force == True, force the compute features importance if it's already done - width : Int (default: 900) - Plotly figure - layout width - height : Int (default: 500) - Plotly figure - layout height - file_name: string (optional) - File name to use to save the plotly bar chart. If None the bar chart will not be saved. - auto_open: Boolean (optional) - Indicate whether to open the bar plot or not. - zoom: bool (default=False) - graph is currently zoomed + mode : str, optional, default: 'global' + Defines the type of plot to display. + - 'global': Displays the feature importance plot from a global perspective. + - 'global-local': Shows the global feature importance plot with local importance indicators. + - 'cumulative': Shows the cumulative sum of feature contributions, ordered by descending importance. + max_features : int, optional, default: 20 + Limits the number of features to display in the plot. + For example, `max_features=20` will display the 20 most important features. + page : int or str, optional, default: 'top' + Allows the user to select which set of features to display. 
+ - 'top': Shows the most important features. + - 'worst': Shows the least important features. + - Page number (integer) allows navigation between different sets of features. + selection : list, optional, default: None + Specifies a subset of features to compare to the global feature importance. + This is only applicable when `mode` is set to 'global'. If provided, the list must contain + indices corresponding to the subset of features to be displayed. + label : int or str, optional, default: -1 + Specifies the label for which to display feature importance in multiclass classification. + If a string label is provided, it will be converted to an integer if applicable. + group_name : str, optional, default: None + Displays feature importance for a specific group of features. + This is only available if the `SmartExplainer` object has been compiled with feature groups. + The group name must correspond to a key in the `features_groups` dictionary. + display_groups : bool, optional, default: True + If feature groups are declared in the `SmartExplainer` object, this parameter specifies + whether or not to display them in the plot. + force : bool, optional, default: False + If `True`, forces recomputation of feature importance, even if it has already been computed. + width : int, optional, default: 900 + The width of the Plotly figure layout. + height : int, optional, default: 500 + The height of the Plotly figure layout. + file_name : str, optional + The name of the file to save the Plotly bar chart. + If `None`, the chart will not be saved. + auto_open : bool, optional + If `True`, automatically opens the generated plot. + zoom : bool, optional, default: False + Indicates whether the graph is currently zoomed in. + normalize_by_nb_samples : bool, optional, default: False + Normalizes feature importance by the number of samples. + This is only applicable when `mode` is set to 'cumulative'. 
+ degree : int, optional, default: 0 + Degree of adjustment to apply to the cumulative feature contributions curve. + This is only applicable when `mode` is set to 'cumulative'. + Returns ------- - Plotly Figure Object - Example + plotly.graph_objs._figure.Figure + The generated Plotly figure object containing the feature importance plot. + + Examples -------- >>> xpl.plot.features_importance() """ @@ -1800,97 +2140,246 @@ def get_feature_importance_page(features_importance, page, max_features): else: raise ValueError("Invalid value for page. It must be 'top', 'worst', or an integer.") - self.explainer.compute_features_import(force=force) - subtitle = None - title = "Features Importance" - display_groups = self.explainer.features_groups is not None and display_groups + # Compute the feature importance based on mode + self._explainer.compute_features_import(force=force, local=(mode == "global-local")) + + # Determine title based on the mode + titles = { + "global": "Feature Importance", + "global-local": "Global and Local Feature Importance", + "cumulative": "Cumulative Feature Contribution Curve", + } + title = titles.get(mode, "Feature Importance") + + # Check if feature groups should be displayed + display_groups = self._explainer.features_groups is not None and display_groups + + # Handle feature groups and group-specific cases + local_imp_lev1, local_imp_lev2 = None, None if display_groups: if group_name: # Case where we have groups of features and we want to display only features inside a group - if group_name not in self.explainer.features_groups.keys(): + if group_name not in self._explainer.features_groups.keys(): raise ValueError( f"group_name parameter : {group_name} is not in features_groups keys. 
" - f"Possible values are : {list(self.explainer.features_groups.keys())}" + f"Possible values are : {list(self._explainer.features_groups.keys())}" ) - title += f" - {truncate_str(self.explainer.features_dict.get(group_name), 20)}" - if isinstance(self.explainer.features_imp, list): + title += f" - {truncate_str(self._explainer.features_dict.get(group_name), 20)}" + if isinstance(self._explainer.features_imp, list): features_importance = [ - label_feat_imp.loc[label_feat_imp.index.isin(self.explainer.features_groups[group_name])] - for label_feat_imp in self.explainer.features_imp + label_feat_imp.loc[label_feat_imp.index.isin(self._explainer.features_groups[group_name])] + for label_feat_imp in self._explainer.features_imp ] + if mode == "global-local": + local_imp_lev1 = [ + label_feat_imp.loc[label_feat_imp.index.isin(self._explainer.features_groups[group_name])] + for label_feat_imp in self._explainer.features_imp_local_lev1 + ] + local_imp_lev2 = [ + label_feat_imp.loc[label_feat_imp.index.isin(self._explainer.features_groups[group_name])] + for label_feat_imp in self._explainer.features_imp_local_lev2 + ] else: - features_importance = self.explainer.features_imp.loc[ - self.explainer.features_imp.index.isin(self.explainer.features_groups[group_name]) - ] - contributions = self.explainer.contributions + index = self._explainer.features_imp.index.isin(self._explainer.features_groups[group_name]) + features_importance = self._explainer.features_imp.loc[index] + if mode == "global-local": + local_imp_lev1 = self._explainer.features_imp_local_lev1.loc[index] + local_imp_lev2 = self._explainer.features_imp_local_lev2.loc[index] + contributions = self._explainer.contributions else: - features_importance = self.explainer.features_imp_groups - contributions = self.explainer.contributions_groups + features_importance = self._explainer.features_imp_groups + if mode == "global-local": + local_imp_lev1 = self._explainer.features_imp_groups_local_lev1 + local_imp_lev2 = 
self._explainer.features_imp_groups_local_lev2 + contributions = self._explainer.contributions_groups else: - features_importance = self.explainer.features_imp - contributions = self.explainer.contributions - - # classification - if self.explainer._case == "classification": - label_num, _, label_value = self.explainer.check_label_name(label) - global_feat_imp = get_feature_importance_page(features_importance[label_num], page, max_features) - if selection is not None: - subset_feat_imp = self.explainer.backend.get_global_features_importance( - contributions=contributions[label_num], explain_data=self.explainer.explain_data, subset=selection - ) - else: - subset_feat_imp = None + features_importance = self._explainer.features_imp + if mode == "global-local": + local_imp_lev1 = self._explainer.features_imp_local_lev1 + local_imp_lev2 = self._explainer.features_imp_local_lev2 + contributions = self._explainer.contributions + + subtitle = "" + + # Classification case + if self._explainer._case == "classification": + label_num, _, label_value = self._explainer.check_label_name(label) + features_importance_case = features_importance[label_num] + contributions_case = contributions[label_num] subtitle = f"Response: {label_value}" - # regression - elif self.explainer._case == "regression": - global_feat_imp = get_feature_importance_page(features_importance, page, max_features) - if selection is not None: - subset_feat_imp = self.explainer.backend.get_global_features_importance( - contributions=contributions, explain_data=self.explainer.explain_data, subset=selection - ) - else: - subset_feat_imp = None - addnote = "" + # Regression case + elif self._explainer._case == "regression": + label_num = None + features_importance_case = features_importance + contributions_case = contributions + else: + raise ValueError("Invalid case. 
Case must be either 'classification' or 'regression'.") + + global_feat_imp = get_feature_importance_page(features_importance_case, page, max_features) + + if mode == "global-local": + local_imp_lev1, local_imp_lev2 = self._get_local_feature_importance( + global_feat_imp.index, local_imp_lev1, local_imp_lev2, label_num + ) + subset_feat_imp = self._get_subset_importance(contributions_case, selection) if subset_feat_imp is not None: subset_feat_imp = subset_feat_imp.reindex(global_feat_imp.index) - subset_feat_imp.index = subset_feat_imp.index.map(self.explainer.features_dict) + subset_feat_imp.index = subset_feat_imp.index.map(self._explainer.features_dict) if subset_feat_imp.dropna().shape[0] == 0: raise ValueError("selection argument doesn't return any row") + + addnote = self._build_additional_notes(subset_feat_imp, selection, max_features) + + # Map feature names + global_feat_imp.index = global_feat_imp.index.map(self._explainer.features_dict) + if mode == "global-local": + local_imp_lev1.index = local_imp_lev1.index.map(self._explainer.features_dict) + local_imp_lev2.index = local_imp_lev2.index.map(self._explainer.features_dict) + + # Format indices if display_groups is enabled + if display_groups: + global_feat_imp, local_imp_lev1, local_imp_lev2, subset_feat_imp = self._apply_bold_formatting( + global_feat_imp, local_imp_lev1, local_imp_lev2, subset_feat_imp, mode + ) + + # Generate and return the plot + return self._generate_feature_importance_plot( + mode, + global_feat_imp, + contributions_case, + local_imp_lev1, + local_imp_lev2, + subset_feat_imp, + title, + addnote, + subtitle, + width, + height, + file_name, + auto_open, + zoom, + normalize_by_nb_samples, + degree, + ) + + def _get_group_feature_importance(self, group_name): + """Retrieve the feature importance for a specific group of features.""" + if isinstance(self._explainer.features_imp, list): + return [ + 
label_feat_imp.loc[label_feat_imp.index.isin(self._explainer.features_groups[group_name])] + for label_feat_imp in self._explainer.features_imp + ] + return self._explainer.features_imp.loc[ + self._explainer.features_imp.index.isin(self._explainer.features_groups[group_name]) + ] + + def _get_local_feature_importance(self, indices, local_imp_lev1, local_imp_lev2, label_num=None): + """Retrieve local feature importance for global-local mode.""" + if label_num is not None: + local_imp_lev1 = local_imp_lev1[label_num].loc[indices] + local_imp_lev2 = local_imp_lev2[label_num].loc[indices] + else: + local_imp_lev1 = local_imp_lev1.loc[indices] + local_imp_lev2 = local_imp_lev2.loc[indices] + return local_imp_lev1, local_imp_lev2 + + def _get_subset_importance(self, contributions, selection): + """Retrieve feature importance for a subset of features, if specified.""" + if selection is not None: + return self._explainer.backend.get_global_features_importance( + contributions=contributions, explain_data=self._explainer.explain_data, subset=selection + ) + return None + + def _build_additional_notes(self, subset_feat_imp, selection, max_features): + """Generate additional notes to display in the plot.""" + addnote = "" + if subset_feat_imp is not None: subset_len = len(selection) - total_len = self.explainer.x_init.shape[0] + total_len = self._explainer.x_init.shape[0] addnote = add_text( [addnote, f"Subset length: {subset_len} ({int(np.round(100 * subset_len / total_len))}%)"], sep=" - " ) - if self.explainer.x_init.shape[1] >= max_features: - addnote = add_text([addnote, f"Total number of features: {int(self.explainer.x_init.shape[1])}"], sep=" - ") + if self._explainer.x_init.shape[1] >= max_features: + addnote = add_text( + [addnote, f"Total number of features: {int(self._explainer.x_init.shape[1])}"], sep=" - " + ) + return addnote - global_feat_imp.index = global_feat_imp.index.map(self.explainer.features_dict) - if display_groups: - # Bold font for groups of 
features - global_feat_imp.index = [ - ( - "" + str(f) - if self.explainer.inv_features_dict.get(f) in self.explainer.features_groups.keys() - else str(f) - ) - for f in global_feat_imp.index - ] - if subset_feat_imp is not None: - subset_feat_imp.index = [ - ( - "" + str(f) - if self.explainer.inv_features_dict.get(f) in self.explainer.features_groups.keys() - else str(f) - ) - for f in subset_feat_imp.index - ] + def _apply_bold_formatting(self, global_feat_imp, local_imp_lev1, local_imp_lev2, subset_feat_imp, mode): + """Apply bold formatting to feature names for feature groups.""" - fig = self.plot_features_import( - global_feat_imp, subset_feat_imp, title, addnote, subtitle, width, height, file_name, auto_open, zoom - ) - return fig + def bold_feature_name(index): + feature_name = str(index) + if self._explainer.inv_features_dict.get(index) in self._explainer.features_groups: + return f"{feature_name}" + return feature_name + + global_feat_imp.index = [bold_feature_name(f) for f in global_feat_imp.index] + + if mode == "global-local": + local_imp_lev1.index = [bold_feature_name(f) for f in global_feat_imp.index] + local_imp_lev2.index = [bold_feature_name(f) for f in global_feat_imp.index] + if subset_feat_imp is not None: + subset_feat_imp.index = [bold_feature_name(f) for f in subset_feat_imp.index] + return global_feat_imp, local_imp_lev1, local_imp_lev2, subset_feat_imp - def plot_line_comparison( + def _generate_feature_importance_plot( + self, + mode, + global_feat_imp, + contributions_case, + local_imp_lev1=None, + local_imp_lev2=None, + subset_feat_imp=None, + title="", + addnote="", + subtitle="", + width=900, + height=500, + file_name=None, + auto_open=False, + zoom=False, + normalize_by_nb_samples=False, + degree="slider", + ): + """Generate the feature importance plot based on the mode.""" + if mode == "global": + return self._plot_features_import( + global_feat_imp, subset_feat_imp, title, addnote, subtitle, width, height, file_name, auto_open, 
zoom + ) + elif mode == "global-local": + feat_imp = {"global": global_feat_imp, "semi-local": local_imp_lev1, "local": local_imp_lev2} + return self._plot_local_features_import( + feat_imp, + title, + addnote, + subtitle, + width, + height, + file_name, + auto_open, + zoom, + ) + elif mode == "cumulative": + return self._plot_feature_contributions_cumulative( + global_feat_imp, + contributions_case, + title, + addnote, + subtitle, + width, + height, + normalize_by_nb_samples, + degree, + file_name, + auto_open, + zoom, + ) + else: + raise ValueError("Invalid value for mode. It must be 'global', 'global-local', or 'cumulative'.") + + def _plot_line_comparison( self, index, feature_values, @@ -2073,14 +2562,14 @@ def compare_plot( line_reference = [] if index is not None: for ident in index: - if ident in self.explainer.x_init.index: + if ident in self._explainer.x_init.index: line_reference.append(ident) elif row_num is not None: line_reference = [ - self.explainer.x_init.index.values[row_nb_reference] + self._explainer.x_init.index.values[row_nb_reference] for row_nb_reference in row_num - if self.explainer.x_init.index.values[row_nb_reference] in self.explainer.x_init.index + if self._explainer.x_init.index.values[row_nb_reference] in self._explainer.x_init.index ] subtitle = "" @@ -2088,15 +2577,15 @@ def compare_plot( raise ValueError("No matching entry for index") # Classification case - if self.explainer._case == "classification": + if self._explainer._case == "classification": if label is None: label = -1 - label_num, _, label_value = self.explainer.check_label_name(label) - contrib = self.explainer.contributions[label_num] + label_num, _, label_value = self._explainer.check_label_name(label) + contrib = self._explainer.contributions[label_num] if show_predict: - preds = [self.local_pred(line, label_num) for line in line_reference] + preds = [self._local_pred(line, label_num) for line in line_reference] subtitle = ( f"Response: {label_value} - " + "Probas: " 
@@ -2106,11 +2595,11 @@ def compare_plot( ) # Regression case - elif self.explainer._case == "regression": - contrib = self.explainer.contributions + elif self._explainer._case == "regression": + contrib = self._explainer.contributions if show_predict: - preds = [self.local_pred(line) for line in line_reference] + preds = [self._local_pred(line) for line in line_reference] subtitle = "Predictions: " + " ; ".join( [str(id) + ": " + str(round(pred, 2)) + "" for id, pred in zip(line_reference, preds)] ) @@ -2122,13 +2611,13 @@ def compare_plot( # Well labels if available feature_values = [0] * len(contrib.columns) - if hasattr(self.explainer, "columns_dict"): + if hasattr(self._explainer, "columns_dict"): for i, name in enumerate(contrib.columns): - feature_name = self.explainer.features_dict[name] + feature_name = self._explainer.features_dict[name] feature_values[i] = feature_name - preds = [self.explainer.x_init.loc[id] for id in line_reference] - dict_features = self.explainer.inv_features_dict + preds = [self._explainer.x_init.loc[id] for id in line_reference] + dict_features = self._explainer.inv_features_dict iteration_list = list(zip(new_contrib, feature_values)) iteration_list.sort(key=lambda x: maximum_difference_sort_value(x), reverse=True) @@ -2136,7 +2625,7 @@ def compare_plot( iteration_list = iteration_list[::-1] new_contrib, feature_values = list(zip(*iteration_list)) - fig = self.plot_line_comparison( + fig = self._plot_line_comparison( line_reference, feature_values, new_contrib, @@ -2348,10 +2837,10 @@ def _select_indices_interactions_plot(self, selection, max_points): # interaction_selection attribute is used to store already computed indices of interaction_values if hasattr(self, "interaction_selection"): list_ind = self.interaction_selection - elif self.explainer.x_init.shape[0] <= max_points: - list_ind = self.explainer.x_init.index.tolist() + elif self._explainer.x_init.shape[0] <= max_points: + list_ind = self._explainer.x_init.index.tolist() 
else: - list_ind = random.sample(self.explainer.x_init.index.tolist(), max_points) + list_ind = random.sample(self._explainer.x_init.index.tolist(), max_points) addnote = "Length of random Subset : " elif isinstance(selection, list): if len(selection) <= max_points: @@ -2364,7 +2853,7 @@ def _select_indices_interactions_plot(self, selection, max_points): list_ind = random.sample(selection, max_points) addnote = "Length of random Subset : " else: - ValueError("parameter selection must be a list") + raise ValueError("parameter selection must be a list") self.interaction_selection = list_ind return list_ind, addnote @@ -2421,36 +2910,36 @@ def interactions_plot( if not (isinstance(col1, (str, int)) or isinstance(col2, (str, int))): raise ValueError("parameters col1 and col2 must be string or int.") - col_id1 = self.explainer.check_features_name([col1])[0] - col_name1 = self.explainer.columns_dict[col_id1] + col_id1 = self._explainer.check_features_name([col1])[0] + col_name1 = self._explainer.columns_dict[col_id1] - col_id2 = self.explainer.check_features_name([col2])[0] - col_name2 = self.explainer.columns_dict[col_id2] + col_id2 = self._explainer.check_features_name([col2])[0] + col_name2 = self._explainer.columns_dict[col_id2] - col_value_count1 = self.explainer.features_desc[col_name1] + col_value_count1 = self._explainer.features_desc[col_name1] list_ind, addnote = self._select_indices_interactions_plot(selection=selection, max_points=max_points) if addnote is not None: addnote = add_text( - [addnote, f"{len(list_ind)} ({int(np.round(100 * len(list_ind) / self.explainer.x_init.shape[0]))}%)"], + [addnote, f"{len(list_ind)} ({int(np.round(100 * len(list_ind) / self._explainer.x_init.shape[0]))}%)"], sep="", ) # Subset - if self.explainer.postprocessing_modifications: - feature_values1 = self.explainer.x_contrib_plot.loc[list_ind, col_name1].to_frame() - feature_values2 = self.explainer.x_contrib_plot.loc[list_ind, col_name2].to_frame() + if 
self._explainer.postprocessing_modifications: + feature_values1 = self._explainer.x_contrib_plot.loc[list_ind, col_name1].to_frame() + feature_values2 = self._explainer.x_contrib_plot.loc[list_ind, col_name2].to_frame() else: - feature_values1 = self.explainer.x_init.loc[list_ind, col_name1].to_frame() - feature_values2 = self.explainer.x_init.loc[list_ind, col_name2].to_frame() + feature_values1 = self._explainer.x_init.loc[list_ind, col_name1].to_frame() + feature_values2 = self._explainer.x_init.loc[list_ind, col_name2].to_frame() - interaction_values = self.explainer.get_interaction_values(selection=list_ind)[:, col_id1, col_id2] + interaction_values = self._explainer.get_interaction_values(selection=list_ind)[:, col_id1, col_id2] if col_id1 != col_id2: interaction_values = interaction_values * 2 # add break line to X label if necessary - max_len_by_row = max([round(50 / self.explainer.features_desc[feature_values1.columns.values[0]]), 8]) + max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values1.columns.values[0]]), 8]) feature_values1.iloc[:, 0] = feature_values1.iloc[:, 0].apply( add_line_break, args=( @@ -2542,7 +3031,7 @@ def top_interactions_plot( list_ind, addnote = self._select_indices_interactions_plot(selection=selection, max_points=max_points) - interaction_values = self.explainer.get_interaction_values(selection=list_ind) + interaction_values = self._explainer.get_interaction_values(selection=list_ind) sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values) @@ -2553,8 +3042,8 @@ def top_interactions_plot( id0, id1 = ids fig_one_interaction = self.interactions_plot( - col1=self.explainer.columns_dict[id0], - col2=self.explainer.columns_dict[id1], + col1=self._explainer.columns_dict[id0], + col2=self._explainer.columns_dict[id1], selection=selection, violin_maxf=violin_maxf, max_points=max_points, @@ -2581,7 +3070,7 @@ def generate_title_dict(col_name1, col_name2, addnote): 
fig.layout.coloraxis.colorscale = self._style_dict["interactions_col_scale"] fig.update_layout( - xaxis_title=self.explainer.columns_dict[sorted_top_features_indices[0][0]], + xaxis_title=self._explainer.columns_dict[sorted_top_features_indices[0][0]], yaxis_title="Shap interaction value", updatemenus=[ dict( @@ -2589,7 +3078,7 @@ def generate_title_dict(col_name1, col_name2, addnote): buttons=list( [ dict( - label=f"{self.explainer.columns_dict[i]} - {self.explainer.columns_dict[j]}", + label=f"{self._explainer.columns_dict[i]} - {self._explainer.columns_dict[j]}", method="update", args=[ { @@ -2602,17 +3091,17 @@ def generate_title_dict(col_name1, col_name2, addnote): { "xaxis": { "title": { - **{"text": self.explainer.columns_dict[i]}, + **{"text": self._explainer.columns_dict[i]}, **self._style_dict["dict_xaxis"], } }, - "legend": {"title": {"text": self.explainer.columns_dict[j]}}, + "legend": {"title": {"text": self._explainer.columns_dict[j]}}, "coloraxis": { - "colorbar": {"title": {"text": self.explainer.columns_dict[j]}}, + "colorbar": {"title": {"text": self._explainer.columns_dict[j]}}, "colorscale": fig.layout.coloraxis.colorscale, }, "title": generate_title_dict( - self.explainer.columns_dict[i], self.explainer.columns_dict[j], addnote + self._explainer.columns_dict[i], self._explainer.columns_dict[j], addnote ), }, ], @@ -2644,8 +3133,8 @@ def generate_title_dict(col_name1, col_name2, addnote): self._update_interactions_fig( fig=fig, - col_name1=self.explainer.columns_dict[sorted_top_features_indices[0][0]], - col_name2=self.explainer.columns_dict[sorted_top_features_indices[0][1]], + col_name1=self._explainer.columns_dict[sorted_top_features_indices[0][0]], + col_name2=self._explainer.columns_dict[sorted_top_features_indices[0][1]], addnote=addnote, width=width, height=height, @@ -2759,7 +3248,7 @@ def cluster_corr(corr, degree, inplace=False): if df is None: # Use x_init by default - df = self.explainer.x_init.copy() + df = 
self._explainer.x_init.copy() if optimized: categorical_columns = df.select_dtypes(include=["object", "category"]).columns @@ -2809,8 +3298,8 @@ def cluster_corr(corr, degree, inplace=False): coloraxis="coloraxis", text=[ [ - f"Feature 1: {self.explainer.features_dict.get(y, y)}
" - f"Feature 2: {self.explainer.features_dict.get(x, x)}" + f"Feature 1: {self._explainer.features_dict.get(y, y)}
" + f"Feature 2: {self._explainer.features_dict.get(x, x)}" for x in list_features ] for y in list_features @@ -2839,8 +3328,8 @@ def cluster_corr(corr, degree, inplace=False): coloraxis="coloraxis", text=[ [ - f"Feature 1: {self.explainer.features_dict.get(y, y)}
" - f"Feature 2: {self.explainer.features_dict.get(x, x)}" + f"Feature 1: {self._explainer.features_dict.get(y, y)}
" + f"Feature 2: {self._explainer.features_dict.get(x, x)}" for x in list_features ] for y in list_features @@ -2872,7 +3361,7 @@ def cluster_corr(corr, degree, inplace=False): return fig - def plot_amplitude_vs_stability(self, mean_variability, mean_amplitude, column_names, file_name, auto_open): + def _plot_amplitude_vs_stability(self, mean_variability, mean_amplitude, column_names, file_name, auto_open): """ Intermediate function used to display the stability plot when plot_type is "none" Parameters @@ -2897,7 +3386,7 @@ def plot_amplitude_vs_stability(self, mean_variability, mean_amplitude, column_n + "
(standard deviation / mean)
" ) yaxis_title = "Importance
(Average contributions)
" - col_scale, _, _ = self.tuning_colorscale(pd.DataFrame(mean_amplitude)) + col_scale, _, _ = self._tuning_colorscale(pd.DataFrame(mean_amplitude)) hv_text = [ f"Feature: {col}
Importance: {y}
Variability: {x}" for col, x, y in zip(column_names, mean_variability, mean_amplitude) @@ -2934,7 +3423,7 @@ def plot_amplitude_vs_stability(self, mean_variability, mean_amplitude, column_n ) return fig - def plot_stability_distribution( + def _plot_stability_distribution( self, variability, plot_type, mean_amplitude, dataset, column_names, file_name, auto_open ): """ @@ -2971,7 +3460,7 @@ def plot_stability_distribution( var_df = var_df[column_names[mean_amplitude.argsort()]] # Add colorscale - col_scale, _, _ = self.tuning_colorscale(pd.DataFrame(mean_amplitude)) + col_scale, _, _ = self._tuning_colorscale(pd.DataFrame(mean_amplitude)) color_list = mean_amplitude_normalized.tolist() color_list.sort() color_list = [next(pair[1] for pair in col_scale if x <= pair[0]) for x in color_list] @@ -3154,17 +3643,17 @@ def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open fig The figure that will be displayed """ - assert index in self.explainer.x_init.index, "index must exist in pandas dataframe" + assert index in self._explainer.x_init.index, "index must exist in pandas dataframe" - self.explainer.compute_features_stability([index]) + self._explainer.compute_features_stability([index]) - column_names = np.array([self.explainer.features_dict.get(x) for x in self.explainer.x_init.columns]) + column_names = np.array([self._explainer.features_dict.get(x) for x in self._explainer.x_init.columns]) def ordinal(n): return "%d%s" % (n, "tsnrhtdd"[(math.floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4]) # Compute explanations for instance and neighbors - g = self.explainer.local_neighbors["norm_shap"] + g = self._explainer.local_neighbors["norm_shap"] # Reorder indices based on absolute values of the 1st row (i.e. 
the instance) in descending order inds = np.flip(np.abs(g[0, :]).argsort()) @@ -3291,16 +3780,16 @@ def stability_plot( """ # Sampling if selection is None: - if self.explainer.x_init.shape[0] <= max_points: - list_ind = self.explainer.x_init.index.tolist() + if self._explainer.x_init.shape[0] <= max_points: + list_ind = self._explainer.x_init.index.tolist() else: - list_ind = random.sample(self.explainer.x_init.index.tolist(), max_points) + list_ind = random.sample(self._explainer.x_init.index.tolist(), max_points) # By default, don't compute calculation if it has already been done - if (self.explainer.features_stability is None) or self.last_stability_selection or force: - self.explainer.compute_features_stability(list_ind) + if (self._explainer.features_stability is None) or self._last_stability_selection or force: + self._explainer.compute_features_stability(list_ind) else: print("Computed values from previous call are used") - self.last_stability_selection = False + self._last_stability_selection = False elif isinstance(selection, list): if len(selection) == 1: raise ValueError("Selection must include multiple points") @@ -3309,22 +3798,24 @@ def stability_plot( f"Size of selection is bigger than max_points (default: {max_points}).\ Computation time might be affected" ) - self.explainer.compute_features_stability(selection) - self.last_stability_selection = True + self._explainer.compute_features_stability(selection) + self._last_stability_selection = True else: raise ValueError("Parameter selection must be a list") - column_names = np.array([self.explainer.features_dict.get(x) for x in self.explainer.x_init.columns]) + column_names = np.array([self._explainer.features_dict.get(x) for x in self._explainer.x_init.columns]) - variability = self.explainer.features_stability["variability"] - amplitude = self.explainer.features_stability["amplitude"] + variability = self._explainer.features_stability["variability"] + amplitude = 
self._explainer.features_stability["amplitude"] mean_variability = variability.mean(axis=0) mean_amplitude = amplitude.mean(axis=0) # Plot 1 : only show average variability on y-axis if distribution not in ["boxplot", "violin"]: - fig = self.plot_amplitude_vs_stability(mean_variability, mean_amplitude, column_names, file_name, auto_open) + fig = self._plot_amplitude_vs_stability( + mean_variability, mean_amplitude, column_names, file_name, auto_open + ) # Plot 2 : Show distribution of variability else: @@ -3336,10 +3827,10 @@ def stability_plot( variability = variability[:, keep] mean_amplitude = mean_amplitude[keep] - dataset = self.explainer.x_init.iloc[:, keep] + dataset = self._explainer.x_init.iloc[:, keep] column_names = column_names[keep] - fig = self.plot_stability_distribution( + fig = self._plot_stability_distribution( variability, distribution, mean_amplitude, dataset, column_names, file_name, auto_open ) @@ -3384,13 +3875,13 @@ def compacity_plot( """ # Sampling if selection is None: - if self.explainer.x_init.shape[0] <= max_points: - list_ind = self.explainer.x_init.index.tolist() + if self._explainer.x_init.shape[0] <= max_points: + list_ind = self._explainer.x_init.index.tolist() else: - list_ind = random.sample(self.explainer.x_init.index.tolist(), max_points) + list_ind = random.sample(self._explainer.x_init.index.tolist(), max_points) # By default, don't compute calculation if it has already been done - if (self.explainer.features_compacity is None) or self.last_compacity_selection or force: - self.explainer.compute_features_compacity(list_ind, 1 - approx, nb_features) + if (self._explainer.features_compacity is None) or self.last_compacity_selection or force: + self._explainer.compute_features_compacity(list_ind, 1 - approx, nb_features) else: print("Computed values from previous call are used") self.last_compacity_selection = False @@ -3400,13 +3891,13 @@ def compacity_plot( f"Size of selection is bigger than max_points (default: 
{max_points}).\ Computation time might be affected" ) - self.explainer.compute_features_compacity(selection, 1 - approx, nb_features) - self.last_compacity_selection = True + self._explainer.compute_features_compacity(selection, 1 - approx, nb_features) + self._last_compacity_selection = True else: raise ValueError("Parameter selection must be a list") - features_needed = self.explainer.features_compacity["features_needed"] - distance_reached = self.explainer.features_compacity["distance_reached"] + features_needed = self._explainer.features_compacity["features_needed"] + distance_reached = self._explainer.features_compacity["distance_reached"] # Make plots fig = make_subplots( @@ -3531,17 +4022,17 @@ def scatter_plot_prediction( auto_open: bool (default=False) open automatically the plot """ - if self.explainer.y_target is not None: + if self._explainer.y_target is not None: # Sampling - list_ind, addnote = self.explainer.plot._subset_sampling(selection, max_points) + list_ind, addnote = self._explainer.plot._subset_sampling(selection, max_points) # Classification Case - if self.explainer._case == "classification": - fig, subtitle = self.explainer.plot._prediction_classification_plot(list_ind, label) + if self._explainer._case == "classification": + fig, subtitle = self._explainer.plot._prediction_classification_plot(list_ind, label) # Regression Case - elif self.explainer._case == "regression": - fig, subtitle = self.explainer.plot._prediction_regression_plot(list_ind) + elif self._explainer._case == "regression": + fig, subtitle = self._explainer.plot._prediction_regression_plot(list_ind) # Add traces, title and template title = "True Values Vs Predicted Values" @@ -3615,17 +4106,17 @@ def _prediction_classification_plot( """ fig = go.Figure() - label_num, _, label_value = self.explainer.check_label_name(label) + label_num, _, label_value = self._explainer.check_label_name(label) # predict proba Color scale - if self.explainer.proba_values is not None: + if 
self._explainer.proba_values is not None: # Assign proba values of the target - df_proba_target = self.explainer.proba_values.copy() + df_proba_target = self._explainer.proba_values.copy() df_proba_target["proba_target"] = df_proba_target.iloc[:, label_num] proba_values = df_proba_target[["proba_target"]] # Proba subset: proba_values = proba_values.loc[list_ind, :] - target = self.explainer.y_target.loc[list_ind, :] - y_pred = self.explainer.y_pred.loc[list_ind, :] + target = self._explainer.y_target.loc[list_ind, :] + y_pred = self._explainer.y_pred.loc[list_ind, :] df_pred = pd.concat( [proba_values.reset_index(), y_pred.reset_index(drop=True), target.reset_index(drop=True)], axis=1 ) @@ -3706,13 +4197,13 @@ def _prediction_classification_plot( ) fig.update_layout(violingap=0, violinmode="overlay") - if self.explainer.label_dict is not None: + if self._explainer.label_dict is not None: fig.update_xaxes( tickmode="array", tickvals=list(df_pred["target"].unique()), - ticktext=list(df_pred["target"].apply(lambda x: self.explainer.label_dict[x]).unique()), + ticktext=list(df_pred["target"].apply(lambda x: self._explainer.label_dict[x]).unique()), ) - if self.explainer.label_dict is None: + if self._explainer.label_dict is None: fig.update_xaxes(tickvals=sorted(list(df_pred["target"].unique()))) return fig, subtitle @@ -3736,9 +4227,9 @@ def _prediction_regression_plot( fig = go.Figure() subtitle = None - prediction_error = self.explainer.prediction_error + prediction_error = self._explainer.prediction_error if prediction_error is not None: - if (self.explainer.y_target == 0).any().iloc[0]: + if (self._explainer.y_target == 0).any().iloc[0]: subtitle = "Prediction Error = abs(True Values - Predicted Values)" else: subtitle = "Prediction Error = abs(True Values - Predicted Values) / True Values" @@ -3747,9 +4238,9 @@ def _prediction_regression_plot( equal_bins = np.unique(equal_bins) bins_list = [i for i in equal_bins] values = pd.DataFrame(pd.cut([val[0] for val in 
prediction_error.values], bins=bins_list, labels=False)) - col_scale, _, _ = self.tuning_colorscale(values, keep_90_pct=False) + col_scale, _, _ = self._tuning_colorscale(values, keep_90_pct=False) - y_target = self.explainer.y_target.loc[list_ind] + y_target = self._explainer.y_target.loc[list_ind] if len(y_target) > 500: lower_quantile = y_target.iloc[:, 0].quantile(0.005) upper_quantile = y_target.iloc[:, 0].quantile(0.995) @@ -3769,7 +4260,7 @@ def _prediction_regression_plot( y_target_values = y_target.values.flatten() - y_pred = self.explainer.y_pred.loc[y_target.index] + y_pred = self._explainer.y_pred.loc[y_target.index] prediction_error = np.array(prediction_error.loc[y_target.index]) feature_values_array = y_target_values @@ -3804,9 +4295,9 @@ def _prediction_regression_plot( fig.add_trace(density_plot) # round predict - if self.round_digit is None: - self.tuning_round_digit() - y_pred = y_pred.map(lambda x: round(x, self.round_digit)) + if self._round_digit is None: + self._tuning_round_digit() + y_pred = y_pred.map(lambda x: round(x, self._round_digit)) y_pred_flatten = y_pred.values.flatten() hv_text = [ @@ -3911,14 +4402,14 @@ def _no_selection_sampling(self, max_points, col, col_value_count, random_seed): """ Handles sampling when no specific selection is made. 
""" - if self.explainer.x_init.shape[0] <= max_points: - return self.explainer.x_init.index.tolist(), None + if self._explainer.x_init.shape[0] <= max_points: + return self._explainer.x_init.index.tolist(), None elif col is None: - selected_indices = random.sample(self.explainer.x_init.index.tolist(), max_points) + selected_indices = random.sample(self._explainer.x_init.index.tolist(), max_points) return selected_indices, "Length of random Subset: " else: selected_indices = self._intelligent_sampling( - self.explainer.x_init, max_points, col, col_value_count, random_seed + self._explainer.x_init, max_points, col, col_value_count, random_seed ) return selected_indices, "Length of smart Subset: " @@ -3932,7 +4423,7 @@ def _list_selection_sampling(self, selection, max_points, col, col_value_count, selected_indices = random.sample(selection, max_points) return selected_indices, "Length of random Subset: " else: - subset = self.explainer.x_init.loc[selection] + subset = self._explainer.x_init.loc[selection] selected_indices = self._intelligent_sampling(subset, max_points, col, col_value_count, random_seed) return selected_indices, "Length of smart Subset: " @@ -3968,5 +4459,5 @@ def _format_additional_note(self, selected_indices, additional_note): """ Formats the additional note with the length and percentage of the selected subset. 
""" - percentage = int(np.round(100 * len(selected_indices) / self.explainer.x_init.shape[0])) + percentage = int(np.round(100 * len(selected_indices) / self._explainer.x_init.shape[0])) return f"{additional_note}{len(selected_indices)} ({percentage}%)" diff --git a/shapash/explainer/smart_state.py b/shapash/explainer/smart_state.py index b9e78364..05d3b12a 100644 --- a/shapash/explainer/smart_state.py +++ b/shapash/explainer/smart_state.py @@ -301,11 +301,11 @@ def summarize(self, s_contrib, var_dict, x_sorted, mask, columns_dict, features_ """ return summarize(s_contrib, var_dict, x_sorted, mask, columns_dict, features_dict) - def compute_features_import(self, contributions): + def compute_features_import(self, contributions, norm=1): """ Compute a relative features importance, sum of absolute values - ​​of the contributions for each - features importance compute in base 100 + ​​of the contributions for each + features importance compute in base 100 Parameters ---------- contributions: pd.DataFrame @@ -317,7 +317,7 @@ def compute_features_import(self, contributions): feature importance, One row by feature, index of the serie = contributions.columns """ - return compute_features_import(contributions) + return compute_features_import(contributions, norm) def compute_grouped_contributions(self, contributions, features_groups): """ diff --git a/shapash/manipulation/filters.py b/shapash/manipulation/filters.py index 6d959558..0860cf7f 100644 --- a/shapash/manipulation/filters.py +++ b/shapash/manipulation/filters.py @@ -113,7 +113,11 @@ def cutoff_contributions(mask, k=10): pd.Dataframe Mask where only the k-top contributions are considered. 
""" - return mask.replace(False, np.nan).cumsum(axis=1).isin(range(1, k + 1)) + # Convert False values to np.nan explicitly without changing data type + mask_nan = mask.astype(float).replace(0, np.nan) + + # Compute the cumulative sum and check if the index is within the top-k + return mask_nan.cumsum(axis=1).isin(range(1, k + 1)) def combine_masks(masks_list): diff --git a/shapash/manipulation/summarize.py b/shapash/manipulation/summarize.py index 8d4d0437..94a3dfea 100644 --- a/shapash/manipulation/summarize.py +++ b/shapash/manipulation/summarize.py @@ -44,11 +44,11 @@ def summarize_el(dataframe, mask, prefix): return df_summarized_matrix -def compute_features_import(dataframe): +def compute_features_import(dataframe, norm=1): """ Compute a relative features importance, sum of absolute values - ​​of the contributions for each - features importance compute in base 100 + ​​of the contributions for each + features importance compute in base 100 Parameters ---------- dataframe: pd.DataFrame @@ -60,7 +60,7 @@ def compute_features_import(dataframe): feature importance One row by feature, index of the serie = dataframe.columns """ - feat_imp = dataframe.abs().sum().sort_values(ascending=True) + feat_imp = (((dataframe.abs() ** norm).sum()) ** (1 / norm)).sort_values(ascending=True) tot = feat_imp.sum() return feat_imp / tot diff --git a/shapash/style/colors.json b/shapash/style/colors.json index ddb9abd6..2a4595ba 100644 --- a/shapash/style/colors.json +++ b/shapash/style/colors.json @@ -28,9 +28,19 @@ "rgb(0, 70, 92)" ], "contrib_distribution": "rgb(211, 211, 211)", + "feature_contributions_cumulative": [ + "rgb(255, 55, 55)", + "rgb(255, 166, 0)", + "rgb(255, 200, 100)", + "rgb(55, 190, 255)", + "rgb(0, 0, 255)", + "rgb(0, 0, 0)" + ], "featureimp_bar": { "1": "rgba(0, 154, 203, 1)", - "2": "rgba(223, 103, 0, 0.8)" + "2": "rgba(223, 103, 0, 0.8)", + "3": "rgba(240, 195, 162, 0.8)", + "4": "rgba(245, 122, 0, 0.8)" }, "featureimp_groups": { "0": "rgb(10, 204, 143)", @@ 
-128,9 +138,19 @@ "rgb(255, 77, 7)" ], "contrib_distribution": "rgb(211, 211, 211)", + "feature_contributions_cumulative": [ + "rgb(255, 55, 55)", + "rgb(255, 166, 0)", + "rgb(255, 200, 100)", + "rgb(55, 190, 255)", + "rgb(0, 0, 255)", + "rgb(0, 0, 0)" + ], "featureimp_bar": { "1": "rgba(244, 192, 0, 1.0)", - "2": "rgba(52, 55, 54, 0.7)" + "2": "rgba(52, 55, 54, 0.7)", + "3": "rgba(103, 208, 255, 0.8)", + "4": "rgba(0, 98, 128, 0.8)" }, "featureimp_groups": { "0": "rgb(245, 133, 24)", diff --git a/shapash/style/style_utils.py b/shapash/style/style_utils.py index 59968fba..f50e7f82 100644 --- a/shapash/style/style_utils.py +++ b/shapash/style/style_utils.py @@ -100,8 +100,11 @@ def define_style(palette): style_dict["dict_featimp_colors"] = { 1: {"color": featureimp_bar[1], "line": {"color": palette["featureimp_line"], "width": 0.5}}, 2: {"color": featureimp_bar[2]}, + 3: {"color": featureimp_bar[3], "line": {"color": palette["featureimp_line"], "width": 0.5}}, + 4: {"color": featureimp_bar[4], "line": {"color": palette["featureimp_line"], "width": 0.5}}, } style_dict["featureimp_groups"] = convert_string_to_int_keys(palette["featureimp_groups"]) + style_dict["feature_contributions_cumulative"] = palette["feature_contributions_cumulative"] style_dict["init_contrib_colorscale"] = palette["contrib_colorscale"] style_dict["contrib_distribution"] = palette["contrib_distribution"] style_dict["violin_area_classif"] = convert_string_to_int_keys(palette["violin_area_classif"]) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index acbb119e..49c8fffa 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -2135,7 +2135,7 @@ def update_id_card(n_submit, label, sort_by, order, data, index): selected_contrib = get_id_card_contrib( self.explainer.data, index, self.explainer.features_dict, self.explainer.columns_dict, label_num ) - proba = self.explainer.plot.local_pred(index, label_num) + proba = self.explainer.plot._local_pred(index, 
label_num) title_contrib = f"Contribution: {label_value} ({proba.round(2):.2f})" _, _, predicted_label_value = self.explainer.check_label_name( selected_row.loc["_predict_", "feature_value"] diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index ffd3620a..f3b94ca7 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -111,6 +111,7 @@ def setUp(self): self.smart_explainer.proba_values = None self.smart_explainer.features_desc = dict(self.x_init.nunique()) self.smart_explainer.features_compacity = self.features_compacity + self.smart_explainer.inv_features_dict = {v: k for k, v in self.smart_explainer.features_dict.items()} def test_define_style_attributes(self): # clear style attributes @@ -123,7 +124,7 @@ def test_define_style_attributes(self): assert len(list(self.smart_explainer.plot._style_dict.keys())) > 0 @patch("shapash.explainer.smart_explainer.SmartExplainer.filter") - @patch("shapash.explainer.smart_plotter.SmartPlotter.local_pred") + @patch("shapash.explainer.smart_plotter.SmartPlotter._local_pred") def test_local_plot_1(self, local_pred, filter): """ Unit test Local plot 1 @@ -190,7 +191,7 @@ def test_local_plot_3(self, select_lines): @patch("shapash.explainer.smart_explainer.SmartExplainer.filter") @patch("shapash.explainer.smart_plotter.select_lines") - @patch("shapash.explainer.smart_plotter.SmartPlotter.local_pred") + @patch("shapash.explainer.smart_plotter.SmartPlotter._local_pred") def test_local_plot_4(self, local_pred, select_lines, filter): """ Unit test local plot 4 @@ -253,7 +254,7 @@ def test_local_plot_4(self, local_pred, select_lines, filter): @patch("shapash.explainer.smart_explainer.SmartExplainer.filter") @patch("shapash.explainer.smart_plotter.select_lines") - @patch("shapash.explainer.smart_plotter.SmartPlotter.local_pred") + @patch("shapash.explainer.smart_plotter.SmartPlotter._local_pred") def 
test_local_plot_5(self, local_pred, select_lines, filter): """ Unit test local plot 5 @@ -347,7 +348,7 @@ def test_local_plot_5(self, local_pred, select_lines, filter): @patch("shapash.explainer.smart_explainer.SmartExplainer.filter") @patch("shapash.explainer.smart_plotter.select_lines") - @patch("shapash.explainer.smart_plotter.SmartPlotter.local_pred") + @patch("shapash.explainer.smart_plotter.SmartPlotter._local_pred") def test_local_plot_groups_features(self, local_pred, select_lines, filter): """ Unit test local plot 6 for groups of features @@ -515,7 +516,7 @@ def test_local_plot_groups_features(self, local_pred, select_lines, filter): @patch("shapash.explainer.smart_explainer.SmartExplainer.filter") @patch("shapash.explainer.smart_plotter.select_lines") - @patch("shapash.explainer.smart_plotter.SmartPlotter.local_pred") + @patch("shapash.explainer.smart_plotter.SmartPlotter._local_pred") def test_local_plot_multi_index(self, local_pred, select_lines, filter): """ Unit test local plot multi index @@ -584,7 +585,7 @@ def test_get_selection(self): Unit test get selection """ line = ["person_A"] - output = self.smart_explainer.plot.get_selection(line, self.var_dict, self.x_sorted, self.contrib_sorted) + output = self.smart_explainer.plot._get_selection(line, self.var_dict, self.x_sorted, self.contrib_sorted) expected_output = np.array([0, 1]), np.array(["PhD", 34]), np.array([-3.4, 0.78]) assert len(output) == 3 assert np.array_equal(output[0], expected_output[0]) @@ -599,7 +600,7 @@ def test_apply_mask_one_line(self): var_dict = np.array([0, 1]) x_sorted = np.array(["PhD", 34]) contrib_sorted = np.array([-3.4, 0.78]) - output = self.smart_explainer.plot.apply_mask_one_line(line, var_dict, x_sorted, contrib_sorted) + output = self.smart_explainer.plot._apply_mask_one_line(line, var_dict, x_sorted, contrib_sorted) expected_output = np.array([0]), np.array(["PhD"]), np.array([-3.4]) assert len(output) == 3 assert np.array_equal(output[0], expected_output[0]) @@ 
-614,7 +615,7 @@ def test_check_masked_contributions_1(self): var_dict = ["X1", "X2"] x_val = ["PhD", 34] contrib = [-3.4, 0.78] - var_dict, x_val, contrib = self.smart_explainer.plot.check_masked_contributions(line, var_dict, x_val, contrib) + var_dict, x_val, contrib = self.smart_explainer.plot._check_masked_contributions(line, var_dict, x_val, contrib) expected_var_dict = ["X1", "X2"] expected_x_val = ["PhD", 34] expected_contrib = [-3.4, 0.78] @@ -633,7 +634,7 @@ def test_check_masked_contributions_2(self): self.smart_explainer.masked_contributions = pd.DataFrame( data=[[0.0, 2.5], [0.0, 1.6]], columns=["masked_neg", "masked_pos"], index=["person_A", "person_B"] ) - var_dict, x_val, contrib = self.smart_explainer.plot.check_masked_contributions(line, var_dict, x_val, contrib) + var_dict, x_val, contrib = self.smart_explainer.plot._check_masked_contributions(line, var_dict, x_val, contrib) expected_var_dict = ["X1", "X2", "Hidden Positive Contributions"] expected_x_val = ["PhD", 34, ""] expected_contrib = [-3.4, 0.78, 2.5] @@ -655,7 +656,7 @@ def test_plot_bar_chart_1(self): ) expected_output_fig = go.Figure(data=bars, layout=go.Layout(yaxis=dict(type="category"))) self.smart_explainer._case = "regression" - fig_output = self.smart_explainer.plot.plot_bar_chart("ind", var_dict, x_val, contributions) + fig_output = self.smart_explainer.plot._plot_bar_chart("ind", var_dict, x_val, contributions) for part in list(zip(fig_output.data, expected_output_fig.data)): assert part[0].x == part[1].x assert part[0].y == part[1].y @@ -678,7 +679,7 @@ def test_plot_bar_chart_2(self): expected_output_fig = go.Figure(data=bars, layout=go.Layout(yaxis=dict(type="category"))) self.smart_explainer._case = "regression" - fig_output = self.smart_explainer.plot.plot_bar_chart("ind", var_dict, x_val, contributions) + fig_output = self.smart_explainer.plot._plot_bar_chart("ind", var_dict, x_val, contributions) for part in list(zip(fig_output.data, expected_output_fig.data)): assert 
part[0].x == part[1].x assert part[0].y == part[1].y @@ -1120,7 +1121,7 @@ def test_plot_features_import_1(self): Unit test plot features import 1 """ serie1 = pd.Series([0.131, 0.51], index=["col1", "col2"]) - output = self.smart_explainer.plot.plot_features_import(serie1) + output = self.smart_explainer.plot._plot_features_import(serie1) data = go.Bar(x=serie1, y=serie1.index, name="Global", orientation="h") expected_output = go.Figure(data=data) @@ -1135,7 +1136,7 @@ def test_plot_features_import_2(self): """ serie1 = pd.Series([0.131, 0.51], index=["col1", "col2"]) serie2 = pd.Series([0.33, 0.11], index=["col1", "col2"]) - output = self.smart_explainer.plot.plot_features_import(serie1, serie2) + output = self.smart_explainer.plot._plot_features_import(serie1, serie2) data1 = go.Bar(x=serie1, y=serie1.index, name="Global", orientation="h") data2 = go.Bar(x=serie2, y=serie2.index, name="Subset", orientation="h") expected_output = go.Figure(data=[data2, data1]) @@ -1154,7 +1155,7 @@ def test_features_importance_1(self): """ xpl = self.smart_explainer xpl.explain_data = None - output = xpl.plot.features_importance(selection=["person_A", "person_B"]) + output = xpl.plot.features_importance(selection=["person_A", "person_B"], zoom=True) data1 = go.Bar(x=np.array([0.2296, 0.7704]), y=np.array(["Age", "Education"]), name="Subset", orientation="h") @@ -1171,6 +1172,31 @@ def test_features_importance_1(self): assert output.data[1].name == expected_output.data[1].name assert output.data[1].orientation == expected_output.data[1].orientation + def test_features_importance_cumulative_1(self): + """ + Unit test features importance cumulative 1 + """ + xpl = self.smart_explainer + xpl.explain_data = None + output = xpl.plot.features_importance(mode="cumulative", selection=["person_A", "person_B"], zoom=True) + + assert len(output.data) == 2 + assert output.data[0].type == "scatter" + assert output.data[1].type == "scatter" + + def test_features_importance_local_1(self): + """ 
+ Unit test features importance local 1 + """ + xpl = self.smart_explainer + xpl.explain_data = None + output = xpl.plot.features_importance(mode="global-local", selection=["person_A", "person_B"], zoom=True) + + assert len(output.data) == 3 + assert output.data[0].type == "bar" + assert output.data[1].type == "bar" + assert output.data[2].type == "bar" + def test_features_importance_2(self): """ Unit test features importance 2 @@ -1199,6 +1225,41 @@ def test_features_importance_2(self): assert output.data[1].name == expected_output.data[1].name assert output.data[1].orientation == expected_output.data[1].orientation + def test_features_importance_cumulative_2(self): + """ + Unit test features importance cumulative 2 + """ + xpl = self.smart_explainer + # regression + xpl.contributions = self.contrib1 + xpl.backend.state = SmartState() + xpl.explain_data = None + xpl._case = "regression" + xpl.state = SmartState() + output = xpl.plot.features_importance(mode="cumulative", selection=["person_A", "person_B"]) + + assert len(output.data) == 2 + assert output.data[0].type == "scatter" + assert output.data[1].type == "scatter" + + def test_features_importance_local_2(self): + """ + Unit test features importance local 2 + """ + xpl = self.smart_explainer + # regression + xpl.contributions = self.contrib1 + xpl.backend.state = SmartState() + xpl.explain_data = None + xpl._case = "regression" + xpl.state = SmartState() + output = xpl.plot.features_importance(mode="global-local", selection=["person_A", "person_B"]) + + assert len(output.data) == 3 + assert output.data[0].type == "bar" + assert output.data[1].type == "bar" + assert output.data[2].type == "bar" + def test_features_importance_3(self): """ Unit test features importance for groups of features @@ -1248,6 +1309,103 @@ def test_features_importance_3(self): assert output.data[0].name == expected_output.data[0].name assert output.data[0].orientation == expected_output.data[0].orientation + def 
test_features_importance_cumulative_3(self): + """ + Unit test features importance cumulative for groups of features + """ + x_init = pd.DataFrame( + data=np.array([["PhD", 34, 1], ["Master", 27, 0]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + contrib = pd.DataFrame( + data=np.array([[-3.4, 0.78, 1.2], [1.2, 3.6, -0.3]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + smart_explainer = SmartExplainer(model=self.model) + smart_explainer.x_encoded = x_init + smart_explainer.x_init = x_init + smart_explainer.postprocessing_modifications = False + smart_explainer.features_imp_groups = None + smart_explainer.features_imp = None + smart_explainer.features_groups = {"group0": ["X1", "X2"]} + smart_explainer.contributions = [contrib, -contrib] + smart_explainer.features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.inv_features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.model = self.smart_explainer.model + smart_explainer._case, smart_explainer._classes = check_model(self.smart_explainer.model) + smart_explainer.backend = ShapBackend(model=self.smart_explainer.model) + smart_explainer.backend.state = MultiDecorator(SmartState()) + smart_explainer.explain_data = None + smart_explainer.state = MultiDecorator(SmartState()) + smart_explainer.contributions_groups = smart_explainer.state.compute_grouped_contributions( + smart_explainer.contributions, smart_explainer.features_groups + ) + smart_explainer.features_imp_groups = smart_explainer.state.compute_features_import( + smart_explainer.contributions_groups + ) + + output = smart_explainer.plot.features_importance(mode="cumulative") + + assert len(output.data) == 2 + assert output.data[0].type == "scatter" + assert output.data[1].type == "scatter" + + def test_features_importance_local_3(self): + """ + Unit test features importance local for groups of features + """ + x_init = pd.DataFrame( + 
data=np.array([["PhD", 34, 1], ["Master", 27, 0]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + contrib = pd.DataFrame( + data=np.array([[-3.4, 0.78, 1.2], [1.2, 3.6, -0.3]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + smart_explainer = SmartExplainer(model=self.model) + smart_explainer.x_encoded = x_init + smart_explainer.x_init = x_init + smart_explainer.postprocessing_modifications = False + smart_explainer.features_imp_groups = None + smart_explainer.features_imp = None + smart_explainer.features_groups = {"group0": ["X1", "X2"]} + smart_explainer.contributions = [contrib, -contrib] + smart_explainer.features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.inv_features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.model = self.smart_explainer.model + smart_explainer._case, smart_explainer._classes = check_model(self.smart_explainer.model) + smart_explainer.backend = ShapBackend(model=self.smart_explainer.model) + smart_explainer.backend.state = MultiDecorator(SmartState()) + smart_explainer.explain_data = None + smart_explainer.state = MultiDecorator(SmartState()) + smart_explainer.contributions_groups = smart_explainer.state.compute_grouped_contributions( + smart_explainer.contributions, smart_explainer.features_groups + ) + smart_explainer.features_imp_groups = smart_explainer.state.compute_features_import( + smart_explainer.contributions_groups + ) + smart_explainer.features_imp_groups_local_lev1 = smart_explainer.state.compute_features_import( + smart_explainer.contributions_groups, norm=3 + ) + smart_explainer.features_imp_groups_local_lev2 = smart_explainer.state.compute_features_import( + smart_explainer.contributions_groups, norm=7 + ) + + output = smart_explainer.plot.features_importance(mode="global-local") + + assert len(output.data) == 3 + assert output.data[0].type == "bar" + assert output.data[1].type == "bar" + assert 
output.data[2].type == "bar" + def test_features_importance_4(self): """ Unit test features importance for groups of features when displaying a group @@ -1297,12 +1455,103 @@ def test_features_importance_4(self): assert output.data[0].name == expected_output.data[0].name assert output.data[0].orientation == expected_output.data[0].orientation + def test_features_importance_cumulative_4(self): + """ + Unit test features importance cumulative for groups of features when displaying a group + """ + x_init = pd.DataFrame( + data=np.array([["PhD", 34, 1], ["Master", 27, 0]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + contrib = pd.DataFrame( + data=np.array([[-3.4, 0.78, 1.2], [1.2, 3.6, -0.3]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + smart_explainer = SmartExplainer(model=self.model) + smart_explainer.x_encoded = x_init + smart_explainer.x_init = x_init + smart_explainer.postprocessing_modifications = False + smart_explainer.features_imp_groups = None + smart_explainer.features_imp = None + smart_explainer.features_groups = {"group0": ["X1", "X2"]} + smart_explainer.contributions = [contrib, -contrib] + smart_explainer.features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.inv_features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.model = self.smart_explainer.model + smart_explainer.backend = ShapBackend(model=self.smart_explainer.model) + smart_explainer.backend.state = MultiDecorator(SmartState()) + smart_explainer.explain_data = None + smart_explainer._case, smart_explainer._classes = check_model(self.smart_explainer.model) + smart_explainer.state = smart_explainer.backend.state + smart_explainer.contributions_groups = smart_explainer.state.compute_grouped_contributions( + smart_explainer.contributions, smart_explainer.features_groups + ) + smart_explainer.features_imp_groups = smart_explainer.state.compute_features_import( + 
smart_explainer.contributions_groups + ) + + output = smart_explainer.plot.features_importance(mode="cumulative", group_name="group0") + + assert len(output.data) == 2 + assert output.data[0].type == "scatter" + assert output.data[1].type == "scatter" + + def test_features_importance_local_4(self): + """ + Unit test features importance local for groups of features when displaying a group + """ + x_init = pd.DataFrame( + data=np.array([["PhD", 34, 1], ["Master", 27, 0]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + contrib = pd.DataFrame( + data=np.array([[-3.4, 0.78, 1.2], [1.2, 3.6, -0.3]]), + columns=["X1", "X2", "X3"], + index=["person_A", "person_B"], + ) + + smart_explainer = SmartExplainer(model=self.model) + smart_explainer.x_encoded = x_init + smart_explainer.x_init = x_init + smart_explainer.postprocessing_modifications = False + smart_explainer.features_imp_groups = None + smart_explainer.features_imp = None + smart_explainer.features_groups = {"group0": ["X1", "X2"]} + smart_explainer.contributions = [contrib, -contrib] + smart_explainer.features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.inv_features_dict = {"X1": "X1", "X2": "X2", "X3": "X3", "group0": "group0"} + smart_explainer.model = self.smart_explainer.model + smart_explainer.backend = ShapBackend(model=self.smart_explainer.model) + smart_explainer.backend.state = MultiDecorator(SmartState()) + smart_explainer.explain_data = None + smart_explainer._case, smart_explainer._classes = check_model(self.smart_explainer.model) + smart_explainer.state = smart_explainer.backend.state + smart_explainer.contributions_groups = smart_explainer.state.compute_grouped_contributions( + smart_explainer.contributions, smart_explainer.features_groups + ) + smart_explainer.features_imp_groups = smart_explainer.state.compute_features_import( + smart_explainer.contributions_groups + ) + + output = 
smart_explainer.plot.features_importance(mode="global-local", group_name="group0") + + assert len(output.data) == 3 + assert output.data[0].type == "bar" + assert output.data[1].type == "bar" + assert output.data[2].type == "bar" + def test_local_pred_1(self): xpl = self.smart_explainer xpl.proba_values = pd.DataFrame( data=np.array([[0.4, 0.6], [0.3, 0.7]]), columns=["class_1", "class_2"], index=xpl.x_encoded.index.values ) - output = xpl.plot.local_pred("person_A", label=0) + output = xpl.plot._local_pred("person_A", label=0) assert isinstance(output, float) def test_plot_line_comparison_1(self): @@ -1338,7 +1587,7 @@ def test_plot_line_comparison_1(self): ) ) expected_output = go.Figure(data=fig) - output = xpl.plot.plot_line_comparison( + output = xpl.plot._plot_line_comparison( ["person_A", "person_B"], var_dict, contributions, predictions=predictions, dict_features=features_dict ) @@ -1367,7 +1616,7 @@ def test_plot_line_comparison_2(self): ) predictions = [data.loc[id] for id in index] - output = xpl.plot.plot_line_comparison( + output = xpl.plot._plot_line_comparison( index, var_dict, contributions, diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index 492fba96..b12594c3 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -3,12 +3,13 @@ import numpy as np import pandas as pd -from dash import dcc +from dash import dcc, html from sklearn.tree import DecisionTreeClassifier from shapash import SmartExplainer from shapash.webapp.smart_app import SmartApp from shapash.webapp.utils.callbacks import ( + create_dropdown_feature_filter, create_filter_modalities_selection, create_id_card_data, create_id_card_layout, @@ -20,6 +21,7 @@ get_id_card_contrib, get_id_card_features, get_indexes_from_datatable, + handle_page_navigation, select_data_from_bool_filters, select_data_from_date_filters, select_data_from_numeric_filters, @@ -422,3 +424,33 
@@ def test_create_filter_modalities_selection(self): "_column2", {"type": "var_dropdown", "index": 1}, self.smart_app.round_dataframe ) assert type(new_element.children) == dcc.Dropdown + + def test_handle_page_navigation_1(self): + page, selected_feature = handle_page_navigation( + triggered_input="page_left.n_clicks", page="3", selected_feature="column1" + ) + assert page == 2 + assert selected_feature == None + + def test_handle_page_navigation_2(self): + page, selected_feature = handle_page_navigation( + triggered_input="page_right.n_clicks", page="3", selected_feature="column1" + ) + assert page == 4 + assert selected_feature == None + + def test_handle_page_navigation_3(self): + page, selected_feature = handle_page_navigation( + triggered_input="bool_groups.on", page="3", selected_feature="column1" + ) + assert page == 1 + assert selected_feature == None + + def test_handle_page_navigation_4(self): + page, selected_feature = handle_page_navigation(triggered_input="erreur", page="3", selected_feature="column1") + assert page == 3 + assert selected_feature == "column1" + + def test_create_dropdown_feature_filter(self): + dropdown = create_dropdown_feature_filter(1, []) + assert type(dropdown) == html.Div From 5eb60f11e6db6021b63900798ced7336c558ba9c Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Mon, 16 Sep 2024 09:48:58 +0200 Subject: [PATCH 2/3] Feature importance tutorial update --- .../tuto-plot03-features-importance.ipynb | 80 ++++++++++++++++--- 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/tutorial/plots_and_charts/tuto-plot03-features-importance.ipynb b/tutorial/plots_and_charts/tuto-plot03-features-importance.ipynb index d2d815d7..e3279892 100644 --- a/tutorial/plots_and_charts/tuto-plot03-features-importance.ipynb +++ b/tutorial/plots_and_charts/tuto-plot03-features-importance.ipynb @@ -28,8 +28,6 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "from category_encoders import OrdinalEncoder\n", "from 
sklearn.ensemble import ExtraTreesClassifier\n", "from sklearn.model_selection import train_test_split" ] @@ -249,7 +247,7 @@ "metadata": {}, "outputs": [], "source": [ - "clf = ExtraTreesClassifier(n_estimators=200).fit(Xtrain,ytrain)" + "clf = ExtraTreesClassifier(n_estimators=200, random_state=79).fit(Xtrain,ytrain['Survived'])" ] }, { @@ -300,12 +298,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Backend: Shap TreeExplainer\n" + "INFO: Shap explainer type - \n" ] } ], "source": [ - "xpl.compile(x=Xtest)" + "xpl.compile(x=Xtest, y_target=ytest)" ] }, { @@ -394,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -408,15 +406,79 @@ "source": [ "xpl.plot.features_importance(max_features=3)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understand local effect\n", + "\n", + "### Comparing features globally\n", + "This plot allows us to observe how the importance of features varies across different subpopulations. For instance, we can see that in certain subpopulations, the **Port of Embarkation** has a greater impact than the **Ticket Class**, highlighting the local variations in feature significance." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xpl.plot.features_importance(mode='global-local', max_features=10, zoom=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparing the features by their Shapley value curves\n", + "In the plot below, we observe the same effect as before. For example, in certain subpopulations, the **Port of Embarkation** has a greater impact than the **Ticket Class**. 
This offers another way to visualize feature importance both locally and globally.\n", + "\n", + "When the curves cross each other, it indicates that one feature has a higher local effect in a specific subpopulation, but a lower global impact across the entire dataset. On the other hand, if a curve consistently remains higher than another, it signifies that the feature is more important both globally and locally.\n", + "\n", + "After this initial analysis, you can use the contribution plot to gain deeper insights into how a particular feature influences the model's predictions." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xpl.plot.features_importance(mode='cumulative', normalize_by_nb_samples=True, degree=-0.7, zoom=True)" + ] } ], "metadata": { "celltoolbar": "Aucun(e)", "hide_input": false, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "keltarif_39", "language": "python", - "name": "python3" + "name": "keltarif_39" }, "language_info": { "codemirror_mode": { @@ -428,7 +490,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.11" + "version": "3.9.18" } }, "nbformat": 4, From c6fe6113caf7131b6d3e7fe6dc895fbdfd8812ab Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 17 Sep 2024 09:55:19 +0200 Subject: [PATCH 3/3] incorrect indentation --- .pre-commit-config.yaml | 4 ++-- .readthedocs.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 038294a7..5c1355b3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,8 +26,8 @@ repos: exclude: ^(docs/|gdocs/) - id: check-added-large-files args: ['--maxkb=500'] - - id: no-commit-to-branch - args: ['--branch', 'master', 
'--branch', 'develop'] + - id: no-commit-to-branch + args: ['--branch', 'master', '--branch', 'develop'] - repo: https://github.com/psf/black rev: 21.12b0 hooks: diff --git a/.readthedocs.yml b/.readthedocs.yml index bf4c4630..cf0b28de 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -24,5 +24,5 @@ build: # Optionally set the version of Python and requirements required to build your docs python: - install: - - requirements: docs/requirements.txt + install: + - requirements: docs/requirements.txt