From 260b6e84b392431b9be182891265b57a846fc61c Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 3 Sep 2024 10:57:30 +0200 Subject: [PATCH 1/6] Add Pagination Feature for Feature Importance Plot --- shapash/explainer/smart_plotter.py | 40 +++- shapash/utils/explanation_metrics.py | 7 +- shapash/webapp/smart_app.py | 239 +++++++++++++++++------- shapash/webapp/utils/callbacks.py | 267 +++++++++++++++++++++++---- 4 files changed, 447 insertions(+), 106 deletions(-) diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index 708971a9..1741afa8 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -534,7 +534,8 @@ def _create_jittered_points( - A numpy array of jittered points. """ # Creating jittered points - jitter = np.random.normal(mean, std, len(percentages)) + rng = np.random.default_rng(seed=79) + jitter = rng.normal(mean, std, len(percentages)) if np.isnan(percentages).any(): percentages.fill(1) @@ -583,7 +584,8 @@ def prepare_hover_text(self, feature_values, pred, feature_name): def _add_violin_trace(self, fig, name, x, y, side, line_color, hovertext, secondary_y=True): """Adds a Violin trace to the figure.""" # Violin plot has a problem if for one violin all the points have the same contribution value - y = y + np.random.normal(size=y.shape) * (max(y.max(), 0) - min(y.min(), 0)) / 10 ** 8 + rng = np.random.default_rng(seed=79) + y = y + rng.normal(size=y.shape) * (max(y.max(), 0) - min(y.min(), 0)) / 10 ** 8 violin_trace = go.Violin( name=name, x=x, @@ -969,6 +971,8 @@ def plot_features_import( width=width, height=height, title=dict_t, + # xaxis=dict(range=[0, 0.5]), + # yaxis=dict(autorange=True), xaxis_title=dict_xaxis, yaxis_title=dict_yaxis, hovermode="closest", @@ -1718,6 +1722,7 @@ def contribution_plot( def features_importance( self, max_features=20, + page="top", selection=None, label=-1, group_name=None, @@ -1742,6 +1747,8 @@ def features_importance( max_features: int 
(optional, default 20) this argument limit the number of hbar in features importance plot if max_features is 20, plot selects the 20 most important features + page: int, str (optional, default "top") + enable to select the features to plot between "top", "worst" or the number of the page selection: list (optional, default None) This argument allows to represent the importance calculated with a subset. Subset features importance is compared to global in the plot @@ -1777,6 +1784,24 @@ def features_importance( -------- >>> xpl.plot.features_importance() """ + + def get_feature_importance_page(features_importance, page, max_features): + if isinstance(page, int): + nb_features = len(features_importance) + nb_page_max = nb_features // max_features + 1 + page = (page - 1) % nb_page_max + 1 + + if (page == "top") or (page == 1): + return features_importance.tail(max_features) + elif page == "worst": + return features_importance.head(max_features) + elif isinstance(page, int): + start_index = (page - 1) * max_features + end_index = start_index + max_features + return features_importance.iloc[-end_index:-start_index] + else: + raise ValueError("Invalid value for page. 
It must be 'top', 'worst', or an integer.") + self.explainer.compute_features_import(force=force) subtitle = None title = "Features Importance" @@ -1809,7 +1834,7 @@ def features_importance( # classification if self.explainer._case == "classification": label_num, _, label_value = self.explainer.check_label_name(label) - global_feat_imp = features_importance[label_num].tail(max_features) + global_feat_imp = get_feature_importance_page(features_importance[label_num], page, max_features) if selection is not None: subset_feat_imp = self.explainer.backend.get_global_features_importance( contributions=contributions[label_num], explain_data=self.explainer.explain_data, subset=selection @@ -1817,9 +1842,10 @@ def features_importance( else: subset_feat_imp = None subtitle = f"Response: {label_value}" + # regression elif self.explainer._case == "regression": - global_feat_imp = features_importance.tail(max_features) + global_feat_imp = get_feature_importance_page(features_importance, page, max_features) if selection is not None: subset_feat_imp = self.explainer.backend.get_global_features_importance( contributions=contributions, explain_data=self.explainer.explain_data, subset=selection @@ -3652,10 +3678,10 @@ def _prediction_classification_plot( df_wrong_predict.target.values.flatten(), ) ] - + rng = np.random.default_rng(seed=79) fig.add_trace( go.Scatter( - x=df_correct_predict["target"].values.flatten() + np.random.normal(0, 0.02, len(df_correct_predict)), + x=df_correct_predict["target"].values.flatten() + rng.normal(0, 0.02, len(df_correct_predict)), y=df_correct_predict["proba_values"].values.flatten(), mode="markers", marker_color=self._style_dict["prediction_plot"][1], @@ -3669,7 +3695,7 @@ def _prediction_classification_plot( fig.add_trace( go.Scatter( - x=df_wrong_predict["target"].values.flatten() + np.random.normal(0, 0.02, len(df_wrong_predict)), + x=df_wrong_predict["target"].values.flatten() + rng.normal(0, 0.02, len(df_wrong_predict)), 
y=df_wrong_predict["proba_values"].values.flatten(), mode="markers", marker_color=self._style_dict["prediction_plot"][0], diff --git a/shapash/utils/explanation_metrics.py b/shapash/utils/explanation_metrics.py index d7479033..f1232f31 100644 --- a/shapash/utils/explanation_metrics.py +++ b/shapash/utils/explanation_metrics.py @@ -97,7 +97,8 @@ def _get_radius(dataset, n_neighbors, sample_size=500, percentile=95): # Select 500 points max to sample size = min([dataset.shape[0], sample_size]) # Randomly sample points from dataset - sampled_instances = dataset[np.random.randint(0, dataset.shape[0], size), :] + rng = np.random.default_rng(seed=79) + sampled_instances = dataset[rng.randint(0, dataset.shape[0], size), :] # Define normalization vector mean_vector = np.array(dataset, dtype=np.float32).std(axis=0) # Initialize the similarity matrix @@ -109,9 +110,9 @@ def _get_radius(dataset, n_neighbors, sample_size=500, percentile=95): similarity_distance[i, j] = dist similarity_distance[j, i] = dist # Select top n_neighbors - ordered_X = np.sort(similarity_distance)[:, 1 : n_neighbors + 1] + ordered_x = np.sort(similarity_distance)[:, 1 : n_neighbors + 1] # Select the value of the distance that captures XX% of all distances (percentile) - return np.percentile(ordered_X.flatten(), percentile) + return np.percentile(ordered_x.flatten(), percentile) def find_neighbors(selection, dataset, model, mode, n_neighbors=10): diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 036768ad..963bee83 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -1,6 +1,7 @@ """ Main class of Web application Shapash """ + import copy import random import re @@ -18,25 +19,29 @@ from shapash.utils.utils import truncate_str from shapash.webapp.utils.callbacks import ( + adjust_figure_layout, create_dropdown_feature_filter, create_filter_modalities_selection, create_id_card_data, create_id_card_layout, + determine_total_pages_and_display, 
get_feature_contributions_sign_to_show, get_feature_filter_options, get_feature_from_clicked_data, - get_feature_from_features_groups, get_figure_zoom, - get_group_name, get_id_card_contrib, get_id_card_features, get_indexes_from_datatable, + get_selected_feature, + handle_group_display_logic, + handle_page_navigation, + plot_features_importance, select_data_from_bool_filters, select_data_from_date_filters, select_data_from_numeric_filters, select_data_from_prediction_picking, select_data_from_str_filters, - update_click_data_on_subset_changes, + update_click_data_on_subset_changes_if_needed, update_features_to_display, ) from shapash.webapp.utils.explanations import Explanations @@ -795,28 +800,101 @@ def make_skeleton(self): dbc.Card( [ html.Div( - # To drow the global_feature_importance graph + # To draw the global_feature_importance graph self.draw_component("graph", "global_feature_importance"), id="card_global_feature_importance", # Position must be absolute to add the explanation button style={"position": "absolute"}, ), dcc.Store(id="clickdata-store"), + dcc.Store(id="selected-clickdata-store"), html.Div( [ - # Create explanation button on feature importance graph - dbc.Button( - "?", id="open_feature_importance", size="sm", color="warning" + # Create a row to contain the buttons + dbc.Row( + [ + # Placeholder column to center the second button + dbc.Col( + dbc.Button( + "Go Back", + id="goback_feature_importance", + size="sm", + color="warning", + style={"display": "none"}, + ), + width=2, + className="d-flex justify-content-start", + ), + # Centered button column + dbc.Col( + html.Div( + [ + dbc.Button( + "<", + id="page_left", + size="sm", + color="warning", + style={ + "margin": "5px 0 5px 0", + "font-size": "12px", + "line-height": "1.2", + }, + ), + html.Div( + [ + html.Span( + "1", id="page_feature_importance" + ), + html.Span(" / ", id="separator"), + html.Span("1", id="total_pages"), + ], + style={ + "padding": "0 10px", + "align-self": "center", + 
}, + ), + dbc.Button( + ">", + id="page_right", + size="sm", + color="warning", + style={ + "margin": "5px 0 5px 0", + "font-size": "12px", + "line-height": "1.2", + }, + ), + ], + id="page_viewer_feature_importance", + style={"display": "none"}, + ), + width=8, + className="d-flex justify-content-center", + ), + # First button column + dbc.Col( + dbc.Button( + "?", + id="open_feature_importance", + size="sm", + color="warning", + ), + width=2, + className="d-flex justify-content-end", + ), + ], + className="g-0", + align="center", + style={"width": "100%"}, ), - # Create popover for this button + # Create popover for the first button dbc.Popover( - "Click here to have more \ - information on Feature Importance graph.", + "Click here to have more information on Feature Importance graph.", target="open_feature_importance", body=True, trigger="hover", ), - # Create modal associated to this button + # Create modal associated to the first button dbc.Modal( [ # Modal title @@ -831,7 +909,7 @@ def make_skeleton(self): html.A( "Click here for more details", href="https://github.com/MAIF/shapash/blob/master/tutorial/plots_and_charts/tuto-plot03-features-importance.ipynb", - # open new brother tab + # open new browser tab target="_blank", style={"color": self.color[0]}, ), @@ -850,7 +928,11 @@ def make_skeleton(self): ), ], # position must be relative - style={"position": "relative", "left": "96%"}, + style={ + "position": "relative", + "display": "flex", + "justify-content": "space-between", + }, ), ] ) @@ -986,7 +1068,7 @@ def make_skeleton(self): id="modal_prediction_picking", centered=True, size="lg", - ) + ), # Position must be relative ], style={"position": "relative", "left": "97%"}, @@ -1623,6 +1705,11 @@ def update_datatable( Output("global_feature_importance", "figure"), Output("global_feature_importance", "clickData"), Output("clickdata-store", "data"), + Output("selected-clickdata-store", "data"), + Output("page_feature_importance", "children"), + 
Output("total_pages", "children"), + Output("page_viewer_feature_importance", "style"), + Output("goback_feature_importance", "style"), ], [ Input("select_label", "value"), @@ -1634,12 +1721,17 @@ def update_datatable( Input("card_global_feature_importance", "n_clicks"), Input("bool_groups", "on"), Input("ember_global_feature_importance", "n_clicks"), + Input("page_left", "n_clicks"), + Input("page_right", "n_clicks"), + Input("goback_feature_importance", "n_clicks"), ], [ State("global_feature_importance", "clickData"), - State("global_feature_importance", "figure"), + State("global_feature_importance", "selectedData"), State("features", "value"), State("clickdata-store", "data"), + State("selected-clickdata-store", "data"), + State("page_feature_importance", "children"), ], ) def update_feature_importance( @@ -1652,10 +1744,15 @@ def update_feature_importance( n_clicks, bool_group, click_zoom, - clickData, - figure, + click_page_left, + click_page_right, + click_goback, + click_data, + selected_click_data, features, - clickData_store, + click_data_store, + selected_click_data_store, + page, ): """ update feature importance plot according label, click on graph, @@ -1671,7 +1768,6 @@ def update_feature_importance( bool_group: display groups click_zoom: click on zoom button clickData: click on features importance graph - figure: figure of Features Importance graph features: features value clickData_store: previous click on features importance graph ------------------------------------------------------------- @@ -1681,59 +1777,76 @@ def update_feature_importance( previous click on Features Importance graph """ ctx = dash.callback_context - # Zoom is False by Default. 
It becomes True if we click on it - zoom_active = get_figure_zoom(click_zoom) - selection = None - list_index = self.list_index - if clickData is not None and ( - ctx.triggered[0]["prop_id"] - in ["apply_filter.n_clicks", "reset_dropdown_button.n_clicks", "dataset.data"] - or ("del_dropdown_button" in ctx.triggered[0]["prop_id"] and None not in nclicks_del) - ): - clickData = update_click_data_on_subset_changes(clickData) - selected_feature = ( - self.explainer.inv_features_dict.get(get_feature_from_clicked_data(clickData)) if clickData else None + # Determine which triggered input + triggered_input = ctx.triggered[0]["prop_id"] + + # Handle click_data updates based on filters or dataset changes + click_data = update_click_data_on_subset_changes_if_needed(click_data, triggered_input, nclicks_del) + selected_click_data = update_click_data_on_subset_changes_if_needed( + selected_click_data, triggered_input, nclicks_del ) - if self.explainer.features_groups and bool_group: - if ctx.triggered[0]["prop_id"] == "card_global_feature_importance.n_clicks": - # When we click twice on the same bar this will reset the graph - if clickData_store == clickData: - selected_feature = None - selected_feature = get_feature_from_features_groups( - selected_feature, self.explainer.features_groups - ) - elif ctx.triggered[0]["prop_id"] == "ember_global_feature_importance.n_clicks": - selected_feature = get_feature_from_features_groups( - selected_feature, self.explainer.features_groups - ) - else: - selected_feature = None + # Get selected feature from click_data + selected_feature = get_selected_feature(click_data, self.explainer.inv_features_dict) + + # Handle page navigation + page, selected_feature = handle_page_navigation(triggered_input, page, selected_feature) + + # Handle group display logic + selected_feature, group_name, click_data, selected_click_data = handle_group_display_logic( + bool_group, + triggered_input, + selected_feature, + selected_click_data, + click_data, + 
click_data_store, + selected_click_data_store, + self.explainer.features_groups, + self.explainer.features_dict, + ) - selection = get_indexes_from_datatable(data, list_index) + # Get selection indexes from datatable + selection = get_indexes_from_datatable(data, self.list_index) - group_name = get_group_name(selected_feature, self.explainer.features_groups) + # Plot features importance + figure = plot_features_importance( + self.explainer, features, page, selection, label, group_name, bool_group, click_zoom + ) - figure = self.explainer.plot.features_importance( - max_features=features, - selection=selection, - label=label, - group_name=group_name, - display_groups=bool_group, - zoom=zoom_active, + # Determine total pages and display settings + total_pages, display_page, page = determine_total_pages_and_display( + self.explainer, features, bool_group, group_name, page ) - # Adjust graph with adding x axis title - MyGraph.adjust_graph_static(figure, x_ax="Mean absolute Contribution") - figure.layout.clickmode = "event+select" + + if group_name and selected_feature is group_name: + selected_feature = None + + if group_name: + goback_feature_importance = {} + else: + goback_feature_importance = {"display": "none"} + if selected_feature: - self.select_point(figure, clickData) + self.select_point(figure, click_data) + + # Adjust figure layout + adjust_figure_layout(figure) - # font size can be adapted to screen size - nb_car = max([len(figure.data[0].y[i]) for i in range(len(figure.data[0].y))]) - figure.update_layout(yaxis=dict(tickfont={"size": min(round(500 / nb_car), 12)})) - clickData_store = clickData.copy() if clickData is not None else None - return figure, clickData, clickData_store + # Update clickData store + click_data_store = click_data.copy() if click_data is not None else None + selected_click_data_store = selected_click_data.copy() if selected_click_data is not None else None + + return ( + figure, + click_data, + click_data_store, + 
selected_click_data_store, + page, + str(total_pages), + display_page, + goback_feature_importance, + ) @app.callback( Output(component_id="feature_selector", component_property="figure"), diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index c5748e4a..6b6e97d4 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -1,10 +1,17 @@ import datetime -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple, Union + +if TYPE_CHECKING: + from shapash.explainer.smart_explainer import SmartExplainer import dash_bootstrap_components as dbc import numpy as np import pandas as pd from dash import dcc, html +from dash.exceptions import PreventUpdate +from plotly.graph_objs import Figure + +from shapash.webapp.utils.MyGraph import MyGraph def select_data_from_prediction_picking(round_dataframe: pd.DataFrame, selected_data: dict) -> pd.DataFrame: @@ -240,31 +247,10 @@ def get_feature_from_features_groups(selected_feature: Optional[str], features_g if selected_feature in list_sub_features: for k, v in features_groups.items(): if selected_feature in v: - selected_feature = k + return k return selected_feature -def get_group_name(selected_feature: Optional[str], features_groups: Optional[dict]) -> Optional[str]: - """Get the group feature name if the selected feature is one of the groups. - - Parameters - ---------- - selected_feature : Optional[str] - Selected feature - features_groups : Optional[dict] - Groups names and corresponding list of features - - Returns - ------- - Optional[str] - Group feature name - """ - group_name = ( - selected_feature if (features_groups is not None and selected_feature in features_groups.keys()) else None - ) - return group_name - - def get_indexes_from_datatable(data: list, list_index: Optional[list] = None) -> Optional[list]: """Get the indexes of the data. If list_index is given and is the same length than the indexes, there is no subset selected. 
@@ -522,16 +508,20 @@ def create_id_card_layout(selected_data: pd.DataFrame, additional_features_dict: dbc.Col(dbc.Label(row["feature_name"]), width=3, style=label_style), dbc.Col(dbc.Label(row["feature_value"]), width=5, className="id_card_solid"), dbc.Col(width=1), - dbc.Col( - dbc.Row( - dbc.Label(format(row["feature_contrib"], ".4f"), width="auto", style={"padding-top": 0}), - justify="end", - ), - width=2, - className="id_card_solid", - ) - if not np.isnan(row["feature_contrib"]) - else None, + ( + dbc.Col( + dbc.Row( + dbc.Label( + format(row["feature_contrib"], ".4f"), width="auto", style={"padding-top": 0} + ), + justify="end", + ), + width=2, + className="id_card_solid", + ) + if not np.isnan(row["feature_contrib"]) + else None + ), ] ) ) @@ -704,3 +694,214 @@ def create_filter_modalities_selection(value: str, id: dict, round_dataframe: pd ) return new_element + + +def handle_page_navigation(triggered_input: str, page: Union[int, str], selected_feature: str) -> Tuple[int, str]: + """ + Handle the navigation between different pages based on user input. + + Args: + triggered_input (str): The input that triggered the navigation. + page (Union[int, str]): The current page number. + selected_feature (str): The currently selected feature. + + Returns: + Tuple[int, str]: Updated page number and selected feature. + """ + page = int(page) + if triggered_input == "page_left.n_clicks": + page -= 1 + selected_feature = None + elif triggered_input == "page_right.n_clicks": + page += 1 + selected_feature = None + elif triggered_input == "bool_groups.on": + page = 1 + selected_feature = None + return page, selected_feature + + +def update_click_data_on_subset_changes_if_needed(click_data: dict, triggered_input: str, nclicks_del: list) -> dict: + """ + Update the click data when there are changes in the subset of data. + + Args: + click_data (dict): The current click data. + triggered_input (str): The input that triggered the update. 
+ nclicks_del (list): The number of delete clicks. + + Returns: + dict: Updated click data. + """ + if click_data and ( + triggered_input in ["apply_filter.n_clicks", "reset_dropdown_button.n_clicks", "dataset.data"] + or ("del_dropdown_button" in triggered_input and None not in nclicks_del) + ): + click_data = update_click_data_on_subset_changes(click_data) + return click_data + + +def get_selected_feature(click_data: dict, inv_features_dict: dict) -> str: + """ + Retrieve the selected feature from the click data. + + Args: + click_data (dict): The click data. + inv_features_dict (dict): Dictionary mapping feature IDs to feature names. + + Returns: + str: The selected feature, if any. + """ + return inv_features_dict.get(get_feature_from_clicked_data(click_data)) if click_data else None + + +def handle_group_display_logic( + bool_group: bool, + triggered_input: str, + selected_feature: str, + selected_click_data, + click_data: dict, + click_data_store: dict, + selected_click_data_store, + features_groups: dict, + features_dict: dict, +) -> Tuple[str, str, dict]: + """ + Handle the display logic for feature groups. + + Args: + bool_group (bool): Whether to display feature groups. + triggered_input (str): The input that triggered the update. + selected_feature (str): The currently selected feature. + click_data (dict): The current click data. + click_data_store (dict): Stored click data. + features_groups (dict): Dictionary of feature groups. + features_dict (dict): Dictionary of features. + + Returns: + Tuple[str, str, dict]: Updated selected feature, group name, and click data. 
+ """ + group_name = None + selected_feature_group = None + if features_groups and bool_group: + if triggered_input in ["card_global_feature_importance.n_clicks", "ember_global_feature_importance.n_clicks"]: + selected_feature_group = get_feature_from_features_groups(selected_feature, features_groups) + else: + selected_feature = None + + group_name = ( + selected_feature_group + if (features_groups is not None and selected_feature_group in features_groups.keys()) + else None + ) + + print("selected_click_data_store == selected_click_data: ", selected_click_data_store == selected_click_data) + print("click_data_store == click_data: ", click_data_store == click_data) + print("selected_feature: ", selected_feature) + print("group_name", group_name) + print("features_dict.get(group_name, group_name)", features_dict.get(group_name, group_name)) + + if (triggered_input == "card_global_feature_importance.n_clicks") and ( + selected_click_data_store == selected_click_data + ): + raise PreventUpdate + + if (triggered_input == "card_global_feature_importance.n_clicks") and (click_data_store == click_data): + if group_name and selected_feature and selected_feature != group_name: + click_data["points"][0]["label"] = features_dict.get(group_name, group_name) + else: + click_data = None + group_name = None + selected_feature = None + elif triggered_input == "goback_feature_importance.n_clicks": + selected_click_data = None + click_data = None + group_name = None + selected_feature = None + + return selected_feature, group_name, click_data, selected_click_data + + +def plot_features_importance( + explainer: "SmartExplainer", + features: int, + page: int, + selection: list, + label: Union[int, str], + group_name: str, + bool_group: bool, + click_zoom: bool, +) -> Figure: + """ + Plot the features importance graph. + + Args: + explainer (SmartExplainer): The explainer object. + features (int): Number of features to display. + page (int): Current page number. 
+ selection (list): List of selected features. + label (Union[int, str]): Label for the plot. + group_name (str): Name of the feature group. + bool_group (bool): Whether to display groups. + click_zoom (bool): Whether zoom is enabled. + + Returns: + Figure: The features importance plot. + """ + page_to_plot = 1 if group_name else page + # Zoom is False by Default. It becomes True if we click on it + zoom_active = get_figure_zoom(click_zoom) + return explainer.plot.features_importance( + max_features=features, + page=page_to_plot, + selection=selection, + label=label, + group_name=group_name, + display_groups=bool_group, + zoom=zoom_active, + ) + + +def determine_total_pages_and_display( + explainer: "SmartExplainer", features: int, bool_group: bool, group_name: str, page: int +) -> Tuple[int, str, int]: + """ + Determine the total number of pages and the display properties. + + Args: + explainer (SmartExplainer): The explainer object. + features (int): Number of features to display per page. + bool_group (bool): Whether to display groups. + group_name (str): Name of the feature group. + page (int): Current page number. + + Returns: + Tuple[int, str, int]: Total pages, display properties, and updated page number. + """ + display_groups = explainer.features_groups is not None and bool_group + nb_features = len(explainer.features_imp_groups) if display_groups else len(explainer.features_imp) + total_pages = nb_features // features + 1 + if (total_pages == 1) or (group_name): + display_page = {"display": "none"} + else: + display_page = {"display": "flex"} + page = (page - 1) % total_pages + 1 + + return total_pages, display_page, page + + +def adjust_figure_layout(figure: Figure) -> None: + """ + Adjust the layout of the figure. + + Args: + figure (Figure): The figure to adjust. 
+ + Returns: + None + """ + MyGraph.adjust_graph_static(figure, x_ax="Mean absolute Contribution") + figure.layout.clickmode = "event+select" + + nb_car = max([len(figure.data[0].y[i]) for i in range(len(figure.data[0].y))]) + figure.update_layout(yaxis=dict(tickfont={"size": min(round(500 / nb_car), 12)})) From 44852b297a5052e1fb043120b67e0c7d350051b6 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 3 Sep 2024 11:48:54 +0200 Subject: [PATCH 2/6] remove old test --- tests/unit_tests/webapp/utils/test_callbacks.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index e90fcbb4..492fba96 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -17,7 +17,6 @@ get_feature_from_clicked_data, get_feature_from_features_groups, get_figure_zoom, - get_group_name, get_id_card_contrib, get_id_card_features, get_indexes_from_datatable, @@ -242,14 +241,6 @@ def test_get_feature_from_features_groups(self): feature = get_feature_from_features_groups("column2", features_groups) assert feature == "column2" - def test_get_group_name(self): - features_groups = {"A": ["column1", "column3"]} - feature = get_group_name("A", features_groups) - assert feature == "A" - - feature = get_group_name("column3", features_groups) - assert feature == None - def test_get_indexes_from_datatable(self): data = self.smart_app.components["table"]["dataset"].data subset = get_indexes_from_datatable(data) From a8151c3e8f072aefe1b4aaf908877e52b86fca5d Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 3 Sep 2024 13:59:43 +0200 Subject: [PATCH 3/6] fix random number generation --- shapash/utils/explanation_metrics.py | 2 +- .../utils/test_explanation_metrics.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/shapash/utils/explanation_metrics.py b/shapash/utils/explanation_metrics.py 
index f1232f31..41c9bc41 100644 --- a/shapash/utils/explanation_metrics.py +++ b/shapash/utils/explanation_metrics.py @@ -98,7 +98,7 @@ def _get_radius(dataset, n_neighbors, sample_size=500, percentile=95): size = min([dataset.shape[0], sample_size]) # Randomly sample points from dataset rng = np.random.default_rng(seed=79) - sampled_instances = dataset[rng.randint(0, dataset.shape[0], size), :] + sampled_instances = dataset[rng.integers(0, dataset.shape[0], size), :] # Define normalization vector mean_vector = np.array(dataset, dtype=np.float32).std(axis=0) # Initialize the similarity matrix diff --git a/tests/unit_tests/utils/test_explanation_metrics.py b/tests/unit_tests/utils/test_explanation_metrics.py index 4fadc4e5..2ebdef9b 100644 --- a/tests/unit_tests/utils/test_explanation_metrics.py +++ b/tests/unit_tests/utils/test_explanation_metrics.py @@ -30,10 +30,11 @@ def test_compute_distance(self): epsilon = 0 expected = 0.5 t = _compute_distance(x1, x2, mean_vector, epsilon) - assert t == expected + assert np.isclose(t, expected) def test_compute_similarities(self): - df = pd.DataFrame(np.random.randint(0, 100, size=(5, 4)), columns=list("ABCD")).values + rng = np.random.default_rng(seed=79) + df = pd.DataFrame(rng.integers(0, 100, size=(5, 4)), columns=list("ABCD")).values instance = df[0, :] expected_len = 5 expected_dist = 0 @@ -42,12 +43,14 @@ def test_compute_similarities(self): assert t[0] == expected_dist def test_get_radius(self): - df = pd.DataFrame(np.random.randint(0, 100, size=(5, 4)), columns=list("ABCD")).values + rng = np.random.default_rng(seed=79) + df = pd.DataFrame(rng.integers(0, 100, size=(5, 4)), columns=list("ABCD")).values t = _get_radius(df, n_neighbors=3) assert t > 0 def test_find_neighbors(self): - df = pd.DataFrame(np.random.randint(0, 100, size=(15, 4)), columns=list("ABCD")) + rng = np.random.default_rng(seed=79) + df = pd.DataFrame(rng.integers(0, 100, size=(15, 4)), columns=list("ABCD")) selection = [1, 3] X = df.iloc[:, :-1] y 
= df.iloc[:, -1] @@ -58,9 +61,10 @@ def test_find_neighbors(self): assert t[0].shape[1] == X.shape[1] + 2 def test_shap_neighbors(self): - df = pd.DataFrame(np.random.randint(0, 100, size=(15, 4)), columns=list("ABCD")) - contrib = pd.DataFrame(np.random.randint(10, size=(15, 4)), columns=list("EFGH")) - instance = df.values[:2, :] + rng = np.random.default_rng(seed=79) + df = pd.DataFrame(rng.integers(0, 100, size=(15, 4)), columns=list("ABCD")) + contrib = pd.DataFrame(rng.integers(10, size=(15, 4)), columns=list("EFGH")) + instance = df.iloc[:2, :].values extra_cols = np.repeat(np.array([0, 0]), 2).reshape(2, -1) instance = np.append(instance, extra_cols, axis=1) mode = "regression" @@ -70,7 +74,8 @@ def test_shap_neighbors(self): assert t[2].shape == (len(df.columns),) def test_get_min_nb_features(self): - contrib = pd.DataFrame(np.random.randint(10, size=(15, 4)), columns=list("ABCD")) + rng = np.random.default_rng(seed=79) + contrib = pd.DataFrame(rng.integers(10, size=(15, 4)), columns=list("ABCD")) selection = [1, 3] distance = 0.1 mode = "regression" @@ -80,7 +85,8 @@ def test_get_min_nb_features(self): assert len(t) == len(selection) def test_get_distance(self): - contrib = pd.DataFrame(np.random.randint(10, size=(15, 4)), columns=list("ABCD")) + rng = np.random.default_rng(seed=79) + contrib = pd.DataFrame(rng.integers(10, size=(15, 4)), columns=list("ABCD")) selection = [1, 3] nb_features = 2 mode = "regression" From 4c9b2eccd84dde589b1a08561a3feff33a4e4f84 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 3 Sep 2024 14:26:31 +0200 Subject: [PATCH 4/6] remove print --- shapash/webapp/utils/callbacks.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index 6b6e97d4..e19b853f 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -795,12 +795,6 @@ def handle_group_display_logic( else None ) - print("selected_click_data_store == 
selected_click_data: ", selected_click_data_store == selected_click_data) - print("click_data_store == click_data: ", click_data_store == click_data) - print("selected_feature: ", selected_feature) - print("group_name", group_name) - print("features_dict.get(group_name, group_name)", features_dict.get(group_name, group_name)) - if (triggered_input == "card_global_feature_importance.n_clicks") and ( selected_click_data_store == selected_click_data ): From 0a2aba6a6e0f90c06395a2cad3e7217b30bae352 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 3 Sep 2024 14:41:25 +0200 Subject: [PATCH 5/6] remove commented lines --- shapash/explainer/smart_plotter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index 1741afa8..3497ecb0 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -971,8 +971,6 @@ def plot_features_import( width=width, height=height, title=dict_t, - # xaxis=dict(range=[0, 0.5]), - # yaxis=dict(autorange=True), xaxis_title=dict_xaxis, yaxis_title=dict_yaxis, hovermode="closest", From 18565d81e4de7391618646fae3ff612f2cb5095b Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 3 Sep 2024 15:02:40 +0200 Subject: [PATCH 6/6] refactoring --- shapash/webapp/smart_app.py | 13 +++++++--- shapash/webapp/utils/callbacks.py | 40 ------------------------------- 2 files changed, 10 insertions(+), 43 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 963bee83..acbb119e 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -35,7 +35,6 @@ get_selected_feature, handle_group_display_logic, handle_page_navigation, - plot_features_importance, select_data_from_bool_filters, select_data_from_date_filters, select_data_from_numeric_filters, @@ -1810,8 +1809,16 @@ def update_feature_importance( selection = get_indexes_from_datatable(data, self.list_index) # Plot features importance - figure 
= plot_features_importance( - self.explainer, features, page, selection, label, group_name, bool_group, click_zoom + page_to_plot = 1 if group_name else page + zoom_active = get_figure_zoom(click_zoom) + figure = self.explainer.plot.features_importance( + max_features=features, + page=page_to_plot, + selection=selection, + label=label, + group_name=group_name, + display_groups=bool_group, + zoom=zoom_active, ) # Determine total pages and display settings diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index e19b853f..8fbf9c95 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -816,46 +816,6 @@ def handle_group_display_logic( return selected_feature, group_name, click_data, selected_click_data -def plot_features_importance( - explainer: "SmartExplainer", - features: int, - page: int, - selection: list, - label: Union[int, str], - group_name: str, - bool_group: bool, - click_zoom: bool, -) -> Figure: - """ - Plot the features importance graph. - - Args: - explainer (SmartExplainer): The explainer object. - features (int): Number of features to display. - page (int): Current page number. - selection (list): List of selected features. - label (Union[int, str]): Label for the plot. - group_name (str): Name of the feature group. - bool_group (bool): Whether to display groups. - click_zoom (bool): Whether zoom is enabled. - - Returns: - Figure: The features importance plot. - """ - page_to_plot = 1 if group_name else page - # Zoom is False by Default. It becomes True if we click on it - zoom_active = get_figure_zoom(click_zoom) - return explainer.plot.features_importance( - max_features=features, - page=page_to_plot, - selection=selection, - label=label, - group_name=group_name, - display_groups=bool_group, - zoom=zoom_active, - ) - - def determine_total_pages_and_display( explainer: "SmartExplainer", features: int, bool_group: bool, group_name: str, page: int ) -> Tuple[int, str, int]: