From 2f41ef3429634ebbff868e79bc286f749c3b8c71 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Thu, 4 May 2023 09:39:58 +0200 Subject: [PATCH 01/12] Revome self. from callbacks. --- shapash/webapp/smart_app.py | 205 +++++++++++++++++--------------- shapash/webapp/utils/MyGraph.py | 11 +- 2 files changed, 113 insertions(+), 103 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index ff3531c5..b36970c8 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -707,6 +707,7 @@ def make_skeleton(self): # Position must be absolute to add the explanation button style={"position": 'absolute'} ), + dcc.Store(id="clickdata-store"), html.Div([ # Create explanation button on feature importance graph dbc.Button( @@ -1178,8 +1179,8 @@ def draw_filter(self): ] return filter - def select_point(self, - graph, + @staticmethod + def select_point(figure, click_data): """ Method which set the selected point in graph component @@ -1188,9 +1189,8 @@ def select_point(self, if click_data: curve_id = click_data['points'][0]['curveNumber'] point_id = click_data['points'][0]['pointIndex'] - for curve in range( - len(self.components['graph'][graph].figure['data'])): - self.components['graph'][graph].figure['data'][curve].selectedpoints = \ + for curve in range(len(figure['data'])): + figure['data'][curve].selectedpoints = \ [point_id] if curve == curve_id else [] def callback_fullscreen_buttons(self): @@ -1426,6 +1426,8 @@ def update_datatable(selected_data, data: available dataset tooltip_data: tooltip of the dataset columns: columns of the dataset + filtered_subset_info: subset size + filtered_subset_color: subset warning color """ ctx = dash.callback_context df = self.round_dataframe @@ -1521,16 +1523,16 @@ def update_datatable(selected_data, df = self.round_dataframe else: raise dash.exceptions.PreventUpdate - self.components['table']['dataset'].data = df.to_dict('records') - self.components['table']['dataset'].tooltip_data = [ + data = df.to_dict('records') + tooltip_data = [ { column: {'value': str(value), 'type': 'text'} for column, value in row.items() } for row in df.to_dict('rows') ] return ( - self.components['table']['dataset'].data, - self.components['table']['dataset'].tooltip_data, + data, + tooltip_data, columns, filtered_subset_info, filtered_subset_color, @@ -1539,7 +1541,8 @@ def update_datatable(selected_data, @app.callback( [ Output('global_feature_importance', 'figure'), - Output('global_feature_importance', 'clickData') + Output('global_feature_importance', 'clickData'), + Output('clickdata-store', 'data') ], [ Input('select_label', 'value'), @@ -1555,7 +1558,9 @@ def update_datatable(selected_data, ], [ State('global_feature_importance', 'clickData'), - State('features', 'value') + State('global_feature_importance', 'figure'), + State('features', 'value'), + State('clickdata-store', 'data') ] ) def update_feature_importance(label, @@ -1569,7 +1574,9 @@ def update_feature_importance(label, bool_group, click_zoom, clickData, - features): + figure, + features, + clickData_store): """ update feature importance plot according label, click on graph, filters applied and subset selected in prediction picking graph. 
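(A minimal sketch of the stateless-callback pattern this commit introduces, assuming Dash >= 2: the previous click is round-tripped through the new dcc.Store(id="clickdata-store") instead of being kept on self.last_click_data. The remember_click callback and the stripped-down layout below are illustrative only — the real update_feature_importance callback also returns the figure and clickData — but the component ids match the ones added in this patch.)

from dash import Dash, dcc, html, Input, Output, State

app = Dash(__name__)
app.layout = html.Div([
    dcc.Graph(id="global_feature_importance"),
    # Holds the previous clickData between callback invocations,
    # replacing the self.last_click_data attribute.
    dcc.Store(id="clickdata-store"),
])

@app.callback(
    Output("clickdata-store", "data"),
    Input("global_feature_importance", "clickData"),
    State("clickdata-store", "data"),
)
def remember_click(click_data, previous_click):
    # Stateless detection of a repeated click: the previous clickData comes in
    # through the Store's State, so a second click on the same bar can be
    # recognized without mutating the SmartApp instance (the real callback
    # uses this comparison to reset the selected feature).
    if click_data is not None and click_data == previous_click:
        pass  # same bar clicked twice -> caller resets the selection
    return click_data

if __name__ == "__main__":
    app.run_server(debug=True)
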
@@ -1585,11 +1592,14 @@ def update_feature_importance(label, bool_group: display groups click_zoom: click on zoom button clickData: click on features importance graph + figure: figure of Features Importance graph features: features value + clickData_store: previous click on features importance graph ------------------------------------------------------------- return figure of Features Importance graph click on Features Importance graph + previous click on Features Importance graph """ ctx = dash.callback_context # Zoom is False by Default. It becomes True if we click on it @@ -1599,6 +1609,7 @@ def update_feature_importance(label, else: zoom_active = True selection = None + list_index = self.list_index selected_feature = self.explainer.inv_features_dict.get( clickData['points'][0]['label'].replace('', '').replace('', '') ) if clickData else None @@ -1612,7 +1623,7 @@ def update_feature_importance(label, self.label = label selection = [d['_index_'] for d in data] elif ctx.triggered[0]['prop_id'] == 'dataset.data': - self.list_index = [d['_index_'] for d in data] + list_index = [d['_index_'] for d in data] elif ctx.triggered[0]['prop_id'] == 'bool_groups.on': clickData = None # We reset the graph and clicks if we toggle the button # If we have selected data on prediction picking graph @@ -1635,7 +1646,7 @@ def update_feature_importance(label, elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and (selected_data is not None) and (len(selected_data) == 1)): # If there is some filters applied - if (len([d['_index_'] for d in data]) != len(self.list_index)): + if (len([d['_index_'] for d in data]) != len(list_index)): selection = [d['_index_'] for d in data] else: selection = None @@ -1643,7 +1654,7 @@ def update_feature_importance(label, elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and (selected_data is None)): # If there is some filters applied - if (len([d['_index_'] for d in data]) != len(self.list_index)): + if (len([d['_index_'] for d in data]) != len(list_index)): selection = [d['_index_'] for d in data] else: selection = None @@ -1689,23 +1700,24 @@ def update_feature_importance(label, for p in selected_data['points']: row_ids.append(p['customdata']) selection = row_ids - elif (len([d['_index_'] for d in data]) != len(self.list_index)): + elif (len([d['_index_'] for d in data]) != len(list_index)): selection = [d['_index_'] for d in data] else: selection = None # When we click twice on the same bar this will reset the graph - if self.last_click_data == clickData: + if clickData_store == clickData: selected_feature = None list_sub_features = [f for group_features in self.explainer.features_groups.values() for f in group_features] if selected_feature in list_sub_features: - self.last_click_data = clickData - raise PreventUpdate + for k, v in self.explainer.features_groups.items(): + if selected_feature in v: + selected_feature = k else: pass else: # Zoom management to generate graph which have global axis - if len(self.components['graph']['global_feature_importance'].figure['data']) == 1: + if len(figure['data']) == 1: selection = None else: row_ids = [] @@ -1717,40 +1729,37 @@ def update_feature_importance(label, else: # we plot filter subset selection = [d['_index_'] for d in data] - self.last_click_data = clickData if selection is not None and len(selection)==0: selection=None group_name = selected_feature if (self.explainer.features_groups is not None and selected_feature in self.explainer.features_groups.keys()) else None - 
self.components['graph']['global_feature_importance'].figure = \ - self.explainer.plot.features_importance( - max_features=features, - selection=selection, - label=self.label, - group_name=group_name, - display_groups=bool_group, - zoom=zoom_active - ) + figure = self.explainer.plot.features_importance( + max_features=features, + selection=selection, + label=self.label, + group_name=group_name, + display_groups=bool_group, + zoom=zoom_active + ) # Adjust graph with adding x axis title - self.components['graph']['global_feature_importance'].adjust_graph(x_ax='Mean absolute Contribution') - self.components['graph']['global_feature_importance'].figure.layout.clickmode = 'event+select' + MyGraph.adjust_graph_static(figure, x_ax='Mean absolute Contribution') + figure.layout.clickmode = 'event+select' if selected_feature: if self.explainer.features_groups is None: - self.select_point('global_feature_importance', clickData) + self.select_point(figure, clickData) elif selected_feature not in self.explainer.features_groups.keys(): - self.select_point('global_feature_importance', clickData) + self.select_point(figure, clickData) # font size can be adapted to screen size - nb_car = max([len(self.components['graph']['global_feature_importance'].figure.data[0].y[i]) for i in - range(len(self.components['graph']['global_feature_importance'].figure.data[0].y))]) - self.components['graph']['global_feature_importance'].figure.update_layout( + nb_car = max([len(figure.data[0].y[i]) for i in + range(len(figure.data[0].y))]) + figure.update_layout( yaxis=dict(tickfont={'size': min(round(500 / nb_car), 12)}) ) - - self.last_click_data = clickData - return self.components['graph']['global_feature_importance'].figure, clickData + clickData_store = clickData.copy() if clickData is not None else None + return figure, clickData, clickData_store @app.callback( Output(component_id='feature_selector', component_property='figure'), @@ -1767,7 +1776,8 @@ def update_feature_importance(label, ], [ State('points', 'value'), - State('violin', 'value') + State('violin', 'value'), + State('global_feature_importance', 'figure'), ] ) def update_feature_selector(feature, @@ -1780,7 +1790,8 @@ def update_feature_selector(feature, is_open, click_zoom, points, - violin): + violin, + gfi_figure): """ Update feature plot according to label, data, selected feature on features importance graph, @@ -1797,9 +1808,10 @@ def update_feature_selector(feature, click_zoom: click on zoom button points: points value in setting violin: violin value in setting + gfi_figure: figure of Features Importance graph --------------------------------------------- return - figure: feature selector graph + fs_figure: feature selector graph """ # Zoom is False by Default. 
It becomes True if we click on it click = 2 if click_zoom is None else click_zoom @@ -1807,6 +1819,12 @@ def update_feature_selector(feature, zoom_active = False else: zoom_active = True # To check if zoom is activated + subset = None + list_index = self.list_index + if feature is not None: + selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') + else: + selected_feature = self.selected_feature ctx = dash.callback_context if ctx.triggered[0]['prop_id'] == 'modal.is_open': if is_open: @@ -1816,25 +1834,22 @@ def update_feature_selector(feature, self.settings_ini['points'] = self.settings['points'] self.settings['violin'] = violin self.settings_ini['violin'] = self.settings['violin'] - self.subset = None - elif ctx.triggered[0]['prop_id'] == 'select_label.value': - self.label = label elif ctx.triggered[0]['prop_id'] == 'global_feature_importance.clickData': if feature is not None: # Removing bold - self.selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') + selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') if feature['points'][0]['curveNumber'] == 0 and \ - len(self.components['graph']['global_feature_importance'].figure['data']) == 2: + len(gfi_figure['data']) == 2: if selected_data is not None and len(selected_data) > 1: row_ids = [] for p in selected_data['points']: row_ids.append(p['customdata']) - self.subset = row_ids + subset = row_ids else: - self.subset = [d['_index_'] for d in data] + subset = [d['_index_'] for d in data] else: - self.subset = self.list_index + subset = list_index # If we have selected data on prediction picking graph elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and (selected_data is not None)): @@ -1842,65 +1857,64 @@ def update_feature_selector(feature, if selected_data is not None and len(selected_data) > 1: for p in selected_data['points']: row_ids.append(p['customdata']) - self.subset = row_ids + subset = row_ids # if we have click on reset button elif ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks': - self.subset = None + subset = None # If we have clik on Apply filter button elif ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': - self.subset = [d['_index_'] for d in data] + subset = [d['_index_'] for d in data] # If we have click on the last del button elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & (None not in nclicks_del)): - self.subset = None + subset = None else: # Zoom management to generate graph which have global axis - if len(self.components['graph']['global_feature_importance'].figure['data']) == 1: - self.subset = self.list_index - elif (len(self.components['graph']['global_feature_importance'].figure['data']) == 2): + if len(gfi_figure['data']) == 1: + subset = list_index + elif (len(gfi_figure['data']) == 2): if feature is not None: if feature['points'][0]['curveNumber'] == 0: if selected_data is not None and len(selected_data) > 1: row_ids = [] for p in selected_data['points']: row_ids.append(p['customdata']) - self.subset = row_ids + subset = row_ids else: - self.subset = [d['_index_'] for d in data] + subset = [d['_index_'] for d in data] else: - self.subset = self.list_index + subset = list_index else: - self.subset = [d['_index_'] for d in data] + subset = [d['_index_'] for d in data] else: row_ids = [] if selected_data is not None and len(selected_data) > 1: # we plot prediction picking subset for p in selected_data['points']: row_ids.append(p['customdata']) - self.subset = row_ids + 
subset = row_ids else: # we plot filter subset - self.subset = [d['_index_'] for d in data] - subset = self.subset + subset = [d['_index_'] for d in data] if subset is not None and len(subset)==0: subset = None - self.components['graph']['feature_selector'].figure = \ - self.explainer.plot.contribution_plot( - col=self.selected_feature, + fs_figure = self.explainer.plot.contribution_plot( + col=selected_feature, selection=subset, - label=self.label, + label=label, violin_maxf=violin, max_points=points, zoom=zoom_active ) - self.components['graph']['feature_selector'].figure['layout'].clickmode = 'event+select' + fs_figure['layout'].clickmode = 'event+select' # Adjust graph with adding x and y axis titles - self.components['graph']['feature_selector'].adjust_graph( - x_ax=truncate_str(self.selected_feature, 110), + MyGraph.adjust_graph_static(fs_figure, + #x_ax=truncate_str(self.layout.selected_feature, 110), + x_ax=truncate_str(selected_feature, 110), y_ax='Contribution') - return self.components['graph']['feature_selector'].figure + return fs_figure @app.callback( [ @@ -1951,10 +1965,8 @@ def update_index_id(click_data, if ctx.triggered[0]['prop_id'] != 'dataset.data': if ctx.triggered[0]['prop_id'] == 'feature_selector.clickData': selected = click_data['points'][0]['customdata'][1] - self.click_graph = True elif ctx.triggered[0]['prop_id'] == 'prediction_picking.clickData': selected = prediction_picking['points'][0]['customdata'] - self.click_graph = True elif ctx.triggered[0]['prop_id'] == 'dataset.active_cell': if cell is not None: selected = data[cell['row']]['_index_'] @@ -1984,7 +1996,6 @@ def update_max_contrib_label(value): """ update max_contrib label """ - self.components['filter']['max_contrib']['max_contrib_id'].value = value return f'Features to display: {value}' @app.callback( @@ -1993,10 +2004,12 @@ def update_max_contrib_label(value): Output('max_contrib_id', 'marks') ], [Input('modal', 'is_open')], - [State('features', 'value')] + [State('features', 'value'), + State('max_contrib_id', 'value')] ) def update_max_contrib_id(is_open, - features): + features, + value): """ update max contrib component layout after settings modifications """ @@ -2019,7 +2032,7 @@ def update_max_contrib_id(is_open, marks = {f'{round(max * feat / nb_marks)}': f'{round(max * feat / nb_marks)}' for feat in range(1, nb_marks + 1)} marks['1'] = '1' - if max < self.components['filter']['max_contrib']['max_contrib_id'].value: + if max < value: value = max else: value = no_update @@ -2116,7 +2129,7 @@ def update_detail_feature(threshold, positive=sign, max_contrib=max_contrib, display_groups=bool_group) - self.components['graph']['detail_feature'].figure = self.explainer.plot.local_plot( + figure = self.explainer.plot.local_plot( index=selected, label=label, show_masked=True, @@ -2126,18 +2139,18 @@ def update_detail_feature(threshold, ) if selected is not None: # Adjust graph with adding x axis titles - self.components['graph']['detail_feature'].adjust_graph(x_ax='Contribution') + MyGraph.adjust_graph_static(figure, x_ax='Contribution') # font size can be adapted to screen size - list_yaxis = [self.components['graph']['detail_feature'].figure.data[i].y[0] for i in - range(len(self.components['graph']['detail_feature'].figure.data))] + list_yaxis = [figure.data[i].y[0] for i in + range(len(figure.data))] # exclude new line with labels of y axis if list_yaxis != []: list_yaxis = [x.split('
')[0] for x in list_yaxis] nb_car = max([len(x) for x in list_yaxis]) - self.components['graph']['detail_feature'].figure.update_layout( + figure.update_layout( yaxis=dict(tickfont={'size': min(round(500 / nb_car), 12)}) ) - return self.components['graph']['detail_feature'].figure + return figure @app.callback( Output("validation", "n_clicks"), @@ -2321,7 +2334,6 @@ def datatable_layout(validation, @app.callback( Output(component_id='prediction_picking', component_property='figure'), [ - Input('global_feature_importance', 'clickData'), Input('dataset', 'data'), Input('apply_filter', 'n_clicks'), Input('reset_dropdown_button', 'n_clicks'), @@ -2335,8 +2347,7 @@ def datatable_layout(validation, State('violin', 'value') ] ) - def update_prediction_picking(feature, - data, + def update_prediction_picking(data, apply_filters, reset_filter, nclicks_del, @@ -2349,7 +2360,6 @@ def update_prediction_picking(feature, Update feature plot according to label, data, selected feature and settings modifications ------------------------------------------------ - feature: click on features importance graph data: the dataset apply_filters: click on apply filter button reset_filter: click on reset filter button @@ -2365,7 +2375,7 @@ def update_prediction_picking(feature, """ ctx = dash.callback_context # Filter subset - filter_subset = None + subset = None if not ctx.triggered: raise dash.exceptions.PreventUpdate if ctx.triggered[0]['prop_id'] == 'modal.is_open': @@ -2376,35 +2386,34 @@ def update_prediction_picking(feature, self.settings_ini['points'] = self.settings['points'] self.settings['violin'] = violin self.settings_ini['violin'] = self.settings['violin'] - self.subset = None elif ctx.triggered[0]['prop_id'] == 'select_label.value': self.label = label # If we have clicked on reset button elif ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks': - self.subset = None + subset = None # If we have clicked on Apply filter button elif ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': - self.subset = [d['_index_'] for d in data] + subset = [d['_index_'] for d in data] # If we have clicked on the last delete button (X) elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & (None not in nclicks_del)): - self.subset = None + subset = None else: raise PreventUpdate - self.components['graph']['prediction_picking'].figure = self.explainer.plot.scatter_plot_prediction( + figure = self.explainer.plot.scatter_plot_prediction( selection=self.subset, max_points=points, label=self.label ) if self.explainer.y_target is not None: - self.components['graph']['prediction_picking'].figure['layout'].clickmode = 'event+select' + figure['layout'].clickmode = 'event+select' # Adjust graph with adding x and y axis titles - self.components['graph']['prediction_picking'].adjust_graph( + MyGraph.adjust_graph_static(figure, x_ax="True Values", y_ax="Predicted Values") - return self.components['graph']['prediction_picking'].figure + return figure @app.callback( Output("modal_feature_importance", "is_open"), diff --git a/shapash/webapp/utils/MyGraph.py b/shapash/webapp/utils/MyGraph.py index 34aad805..10e43612 100644 --- a/shapash/webapp/utils/MyGraph.py +++ b/shapash/webapp/utils/MyGraph.py @@ -43,7 +43,8 @@ def __init__(self, figure, id, style={}, **kwds): 'displaylogo': False } - def adjust_graph(self, + @staticmethod + def adjust_graph_static(figure, x_ax="", y_ax=""): """ @@ -53,8 +54,8 @@ def adjust_graph(self, y_ax: title of the y-axis --------------------------------------- """ - new_title = 
update_title(self.figure.layout.title.text) - self.figure.update_layout( + new_title = update_title(figure.layout.title.text) + figure.update_layout( autosize=True, margin=dict( l=50, @@ -75,11 +76,11 @@ def adjust_graph(self, } ) # update x title and font-size of the title - self.figure.update_xaxes(title='' + x_ax + '', + figure.update_xaxes(title='' + x_ax + '', automargin=True ) # update y title and font-size of the title - self.figure.update_yaxes(title='' + y_ax + '', + figure.update_yaxes(title='' + y_ax + '', automargin=True ) From 716d43d09e41840bb037d704244ea49e7fbf0a9f Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Tue, 9 May 2023 14:34:07 +0200 Subject: [PATCH 02/12] Remove remaining self modifications. --- shapash/webapp/smart_app.py | 64 ++++++++++--------------------------- 1 file changed, 16 insertions(+), 48 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index b36970c8..f3bbaf0d 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -139,7 +139,7 @@ def __init__(self, explainer, settings: dict = None): self.init_callback_settings() self.callback_generator() - def init_data(self): + def init_data(self, rows = None): """ Method which initializes data from explainer object """ @@ -174,10 +174,12 @@ def init_data(self): col_order = self.special_cols + self.dataframe.columns.drop(self.special_cols).tolist() random.seed(79) + if rows is None: + rows = self.settings['rows'] self.list_index = \ random.sample( population=self.dataframe.index.tolist(), - k=min(self.settings['rows'], len(self.dataframe.index.tolist())) + k=min(rows, len(self.dataframe.index.tolist())) ) self.dataframe = self.dataframe[col_order].loc[self.list_index].sort_index() self.round_dataframe = self.dataframe.copy() @@ -1438,9 +1440,7 @@ def update_datatable(selected_data, if is_open: raise PreventUpdate else: - self.settings['rows'] = rows - self.init_data() - self.settings_ini['rows'] = self.settings['rows'] + self.init_data(rows = rows) if name == [1]: columns = [{"name": i, "id": i} for i in self.special_cols] + \ [{"name": self.features_dict[i], "id": i} for i in self.dataframe.columns.drop(self.special_cols)] @@ -1551,7 +1551,6 @@ def update_datatable(selected_data, Input('apply_filter', 'n_clicks'), Input('reset_dropdown_button', 'n_clicks'), Input({'type': 'del_dropdown_button', 'index': ALL}, 'n_clicks'), - Input('modal', 'is_open'), Input('card_global_feature_importance', 'n_clicks'), Input('bool_groups', 'on'), Input('ember_global_feature_importance', 'n_clicks') @@ -1569,7 +1568,6 @@ def update_feature_importance(label, apply_filters, reset_filter, nclicks_del, - is_open, n_clicks, bool_group, click_zoom, @@ -1613,14 +1611,7 @@ def update_feature_importance(label, selected_feature = self.explainer.inv_features_dict.get( clickData['points'][0]['label'].replace('', '').replace('', '') ) if clickData else None - if ctx.triggered[0]['prop_id'] == 'modal.is_open': - if is_open: - raise PreventUpdate - else: - self.settings['features'] = features - self.settings_ini['features'] = self.settings['features'] - elif ctx.triggered[0]['prop_id'] == 'select_label.value': - self.label = label + if ctx.triggered[0]['prop_id'] == 'select_label.value': selection = [d['_index_'] for d in data] elif ctx.triggered[0]['prop_id'] == 'dataset.data': list_index = [d['_index_'] for d in data] @@ -1738,7 +1729,7 @@ def update_feature_importance(label, figure = self.explainer.plot.features_importance( max_features=features, selection=selection, - 
label=self.label, + label=label, group_name=group_name, display_groups=bool_group, zoom=zoom_active @@ -1771,7 +1762,6 @@ def update_feature_importance(label, Input('reset_dropdown_button', 'n_clicks'), Input({'type': 'del_dropdown_button', 'index': ALL}, 'n_clicks'), Input('select_label', 'value'), - Input('modal', 'is_open'), Input('ember_feature_selector', 'n_clicks') ], [ @@ -1787,7 +1777,6 @@ def update_feature_selector(feature, reset_filter, nclicks_del, label, - is_open, click_zoom, points, violin, @@ -1826,16 +1815,7 @@ def update_feature_selector(feature, else: selected_feature = self.selected_feature ctx = dash.callback_context - if ctx.triggered[0]['prop_id'] == 'modal.is_open': - if is_open: - raise PreventUpdate - else: - self.settings['points'] = points - self.settings_ini['points'] = self.settings['points'] - self.settings['violin'] = violin - self.settings_ini['violin'] = self.settings['violin'] - - elif ctx.triggered[0]['prop_id'] == 'global_feature_importance.clickData': + if ctx.triggered[0]['prop_id'] == 'global_feature_importance.clickData': if feature is not None: # Removing bold selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') @@ -2339,7 +2319,6 @@ def datatable_layout(validation, Input('reset_dropdown_button', 'n_clicks'), Input({'type': 'del_dropdown_button', 'index': ALL}, 'n_clicks'), Input('select_label', 'value'), - Input('modal', 'is_open'), Input('ember_prediction_picking', 'n_clicks') ], [ @@ -2352,7 +2331,6 @@ def update_prediction_picking(data, reset_filter, nclicks_del, label, - is_open, click_zoom, points, violin): @@ -2378,33 +2356,23 @@ def update_prediction_picking(data, subset = None if not ctx.triggered: raise dash.exceptions.PreventUpdate - if ctx.triggered[0]['prop_id'] == 'modal.is_open': - if is_open: - raise PreventUpdate - else: - self.settings['points'] = points - self.settings_ini['points'] = self.settings['points'] - self.settings['violin'] = violin - self.settings_ini['violin'] = self.settings['violin'] - elif ctx.triggered[0]['prop_id'] == 'select_label.value': - self.label = label - # If we have clicked on reset button - elif ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks': - subset = None - # If we have clicked on Apply filter button - elif ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': + if ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': subset = [d['_index_'] for d in data] # If we have clicked on the last delete button (X) elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & (None not in nclicks_del)): subset = None - else: + elif ctx.triggered[0]['prop_id'] not in [ + 'dataset.data', + 'select_label.value', + 'reset_dropdown_button.n_clicks' + ]: raise PreventUpdate figure = self.explainer.plot.scatter_plot_prediction( - selection=self.subset, + selection=subset, max_points=points, - label=self.label + label=label ) if self.explainer.y_target is not None: figure['layout'].clickmode = 'event+select' From ee0279402df815e88563b859a0b201dc623ec9c6 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Tue, 9 May 2023 15:23:53 +0200 Subject: [PATCH 03/12] Fix prediciton graph with self modifications. --- shapash/webapp/smart_app.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index f3bbaf0d..9b6f2eec 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -1403,7 +1403,6 @@ def update_datatable(selected_data, filtering and settings modifications. 
------------------------------------------------------------------ selected_data: selected data in prediction picking graph - is_open: modal nclicks_apply: click on Apply Filter button nclicks_reset: click on Reset All Filter button nclicks_del: click on delete button @@ -1585,7 +1584,6 @@ def update_feature_importance(label, apply_filters: click on apply filter button reset_filter: click on reset filter button nclicks_del: click on del button - is_open: modal n_clicks: click on features importance card bool_group: display groups click_zoom: click on zoom button @@ -1793,7 +1791,6 @@ def update_feature_selector(feature, reset_filter: click on reset filter button nclicks_del: click del button label: selected label - is_open: modal click_zoom: click on zoom button points: points value in setting violin: violin value in setting @@ -2319,6 +2316,7 @@ def datatable_layout(validation, Input('reset_dropdown_button', 'n_clicks'), Input({'type': 'del_dropdown_button', 'index': ALL}, 'n_clicks'), Input('select_label', 'value'), + Input('modal', 'is_open'), Input('ember_prediction_picking', 'n_clicks') ], [ @@ -2331,6 +2329,7 @@ def update_prediction_picking(data, reset_filter, nclicks_del, label, + is_open, click_zoom, points, violin): @@ -2363,7 +2362,7 @@ def update_prediction_picking(data, (None not in nclicks_del)): subset = None elif ctx.triggered[0]['prop_id'] not in [ - 'dataset.data', + 'modal.is_open', 'select_label.value', 'reset_dropdown_button.n_clicks' ]: From 830f52df2e1d49b1df535506dde79b8af92d57b2 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Wed, 10 May 2023 15:49:45 +0200 Subject: [PATCH 04/12] Fix prediction plot behavior without self. --- shapash/webapp/smart_app.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 9b6f2eec..dfe44718 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -2309,7 +2309,8 @@ def datatable_layout(validation, return style_data_conditional, style_filter_conditional, style_header_conditional, style_cell_conditional @app.callback( - Output(component_id='prediction_picking', component_property='figure'), + Output('prediction_picking', 'figure'), + Output('prediction_picking', 'selectedData'), [ Input('dataset', 'data'), Input('apply_filter', 'n_clicks'), @@ -2321,7 +2322,8 @@ def datatable_layout(validation, ], [ State('points', 'value'), - State('violin', 'value') + State('violin', 'value'), + State('prediction_picking','selectedData') ] ) def update_prediction_picking(data, @@ -2332,7 +2334,8 @@ def update_prediction_picking(data, is_open, click_zoom, points, - violin): + violin, + selectedData): """ Update feature plot according to label, data, selected feature and settings modifications @@ -2353,20 +2356,14 @@ def update_prediction_picking(data, ctx = dash.callback_context # Filter subset subset = None - if not ctx.triggered: - raise dash.exceptions.PreventUpdate - if ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': - subset = [d['_index_'] for d in data] - # If we have clicked on the last delete button (X) - elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & - (None not in nclicks_del)): - subset = None - elif ctx.triggered[0]['prop_id'] not in [ - 'modal.is_open', - 'select_label.value', - 'reset_dropdown_button.n_clicks' - ]: + if (ctx.triggered[0]['prop_id']=='reset_dropdown_button.n_clicks' or + ('del_dropdown_button' in ctx.triggered[0]['prop_id'] and + None not in nclicks_del)): 
+ selectedData = None + if selectedData is not None and len(selectedData['points'])>0: raise PreventUpdate + else: + subset = [d['_index_'] for d in data] figure = self.explainer.plot.scatter_plot_prediction( selection=subset, @@ -2380,7 +2377,7 @@ def update_prediction_picking(data, x_ax="True Values", y_ax="Predicted Values") - return figure + return figure, selectedData @app.callback( Output("modal_feature_importance", "is_open"), From 79f0f56fd8391cc9f22ef1b192021b1ffc1ead42 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Wed, 10 May 2023 16:18:49 +0200 Subject: [PATCH 05/12] Refacto subset for graphs. --- shapash/webapp/smart_app.py | 215 ++++++------------------------------ 1 file changed, 31 insertions(+), 184 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index dfe44718..dd8408f8 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -1606,119 +1606,37 @@ def update_feature_importance(label, zoom_active = True selection = None list_index = self.list_index + if clickData is not None and (ctx.triggered[0]['prop_id'] in [ + 'apply_filter.n_clicks', + 'reset_dropdown_button.n_clicks', + 'dataset.data' + ] or ('del_dropdown_button' in ctx.triggered[0]['prop_id'] and + None not in nclicks_del)): + point = clickData['points'][0] + point['curveNumber'] = 0 + clickData = {'points':[point]} + selected_feature = self.explainer.inv_features_dict.get( clickData['points'][0]['label'].replace('', '').replace('', '') ) if clickData else None - if ctx.triggered[0]['prop_id'] == 'select_label.value': - selection = [d['_index_'] for d in data] - elif ctx.triggered[0]['prop_id'] == 'dataset.data': - list_index = [d['_index_'] for d in data] - elif ctx.triggered[0]['prop_id'] == 'bool_groups.on': - clickData = None # We reset the graph and clicks if we toggle the button - # If we have selected data on prediction picking graph - elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and - (selected_data is not None) and (len(selected_data) > 1)): - row_ids = [] - if selected_data is not None and len(selected_data) > 1: - for p in selected_data['points']: - row_ids.append(p['customdata']) - selection = row_ids - else: - selection = None - #when group - if self.explainer.features_groups and bool_group: - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] - if selected_feature not in list_sub_features: - selected_feature = None - # If click on a single point on prediction picking, do nothing - elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and - (selected_data is not None) and (len(selected_data) == 1)): - # If there is some filters applied - if (len([d['_index_'] for d in data]) != len(list_index)): - selection = [d['_index_'] for d in data] - else: - selection = None - # If we have dubble click on prediction picking to remove the selected subset - elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and - (selected_data is None)): - # If there is some filters applied - if (len([d['_index_'] for d in data]) != len(list_index)): - selection = [d['_index_'] for d in data] - else: - selection = None - #when group - if self.explainer.features_groups and bool_group: - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] - if selected_feature not in list_sub_features: - selected_feature = None - # If we click on reset filter button - elif 
ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks': - selection = None - #when group - if self.explainer.features_groups and bool_group: - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] - if selected_feature not in list_sub_features: - selected_feature = None - # If we click on Apply button - elif ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': - selection = [d['_index_'] for d in data] - #when group - if self.explainer.features_groups and bool_group: - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] - if selected_feature not in list_sub_features: - selected_feature = None - # If we click on the last del button - elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & - (None not in nclicks_del)): - selection = None - #when group - if self.explainer.features_groups and bool_group: - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] - if selected_feature not in list_sub_features: + + if self.explainer.features_groups and bool_group: + list_sub_features = [f for group_features in self.explainer.features_groups.values() + for f in group_features] + if ctx.triggered[0]['prop_id'] == 'card_global_feature_importance.n_clicks': + # When we click twice on the same bar this will reset the graph + if clickData_store == clickData: selected_feature = None - elif (ctx.triggered[0]['prop_id'] == 'card_global_feature_importance.n_clicks' - and self.explainer.features_groups and bool_group): - row_ids = [] - if selected_data is not None and len(selected_data) > 1: - # we plot prediction picking subset - for p in selected_data['points']: - row_ids.append(p['customdata']) - selection = row_ids - elif (len([d['_index_'] for d in data]) != len(list_index)): - selection = [d['_index_'] for d in data] + if selected_feature in list_sub_features: + for k, v in self.explainer.features_groups.items(): + if selected_feature in v: + selected_feature = k else: - selection = None - # When we click twice on the same bar this will reset the graph - if clickData_store == clickData: selected_feature = None - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] - if selected_feature in list_sub_features: - for k, v in self.explainer.features_groups.items(): - if selected_feature in v: - selected_feature = k - else: - pass - else: - # Zoom management to generate graph which have global axis - if len(figure['data']) == 1: - selection = None - else: - row_ids = [] - if selected_data is not None and len(selected_data) > 1: - # we plot prediction picking subset - for p in selected_data['points']: - row_ids.append(p['customdata']) - selection = row_ids - else: - # we plot filter subset - selection = [d['_index_'] for d in data] - if selection is not None and len(selection)==0: + + + selection = [d['_index_'] for d in data] + if len(selection)==len(list_index) or len(selection)==0: selection=None group_name = selected_feature if (self.explainer.features_groups is not None @@ -1736,10 +1654,7 @@ def update_feature_importance(label, MyGraph.adjust_graph_static(figure, x_ax='Mean absolute Contribution') figure.layout.clickmode = 'event+select' if selected_feature: - if self.explainer.features_groups is None: - self.select_point(figure, clickData) - elif selected_feature not in self.explainer.features_groups.keys(): - self.select_point(figure, clickData) + 
self.select_point(figure, clickData) # font size can be adapted to screen size nb_car = max([len(figure.data[0].y[i]) for i in @@ -1754,11 +1669,7 @@ def update_feature_importance(label, Output(component_id='feature_selector', component_property='figure'), [ Input('global_feature_importance', 'clickData'), - Input('prediction_picking', 'selectedData'), Input('dataset', 'data'), - Input('apply_filter', 'n_clicks'), - Input('reset_dropdown_button', 'n_clicks'), - Input({'type': 'del_dropdown_button', 'index': ALL}, 'n_clicks'), Input('select_label', 'value'), Input('ember_feature_selector', 'n_clicks') ], @@ -1769,11 +1680,7 @@ def update_feature_importance(label, ] ) def update_feature_selector(feature, - selected_data, data, - apply_filters, - reset_filter, - nclicks_del, label, click_zoom, points, @@ -1785,11 +1692,7 @@ def update_feature_selector(feature, filters and settings modifications -------------------------------------------- feature: click on feature importance graph - selected_data: Data selected on prediction picking graph data: dataset - apply_filters: click on apply filter button - reset_filter: click on reset filter button - nclicks_del: click del button label: selected label click_zoom: click on zoom button points: points value in setting @@ -1811,69 +1714,13 @@ def update_feature_selector(feature, selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') else: selected_feature = self.selected_feature - ctx = dash.callback_context - if ctx.triggered[0]['prop_id'] == 'global_feature_importance.clickData': - if feature is not None: - # Removing bold - selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') - if feature['points'][0]['curveNumber'] == 0 and \ - len(gfi_figure['data']) == 2: - if selected_data is not None and len(selected_data) > 1: - row_ids = [] - for p in selected_data['points']: - row_ids.append(p['customdata']) - subset = row_ids - else: - subset = [d['_index_'] for d in data] - else: - subset = list_index - # If we have selected data on prediction picking graph - elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and - (selected_data is not None)): - row_ids = [] - if selected_data is not None and len(selected_data) > 1: - for p in selected_data['points']: - row_ids.append(p['customdata']) - subset = row_ids - # if we have click on reset button - elif ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks': - subset = None - # If we have clik on Apply filter button - elif ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks': + + if feature is not None and feature['points'][0]['curveNumber'] == 0 and \ + len(gfi_figure['data']) == 2: subset = [d['_index_'] for d in data] - # If we have click on the last del button - elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & - (None not in nclicks_del)): - subset = None else: - # Zoom management to generate graph which have global axis - if len(gfi_figure['data']) == 1: - subset = list_index - elif (len(gfi_figure['data']) == 2): - if feature is not None: - if feature['points'][0]['curveNumber'] == 0: - if selected_data is not None and len(selected_data) > 1: - row_ids = [] - for p in selected_data['points']: - row_ids.append(p['customdata']) - subset = row_ids - else: - subset = [d['_index_'] for d in data] - else: - subset = list_index - else: - subset = [d['_index_'] for d in data] - else: - row_ids = [] - if selected_data is not None and len(selected_data) > 1: - # we plot prediction picking subset - for p in 
selected_data['points']: - row_ids.append(p['customdata']) - subset = row_ids - else: - # we plot filter subset - subset = [d['_index_'] for d in data] - if subset is not None and len(subset)==0: + subset = None + if subset is not None and (len(subset)==len(list_index) or len(subset)==0): subset = None fs_figure = self.explainer.plot.contribution_plot( From fd8381400d61a58707e09b49d2f553fbdc0b8ec0 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Fri, 9 Jun 2023 15:35:47 +0200 Subject: [PATCH 06/12] Fix apply filter when selectedData. --- shapash/webapp/smart_app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index dfe44718..95ff4014 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -2356,7 +2356,8 @@ def update_prediction_picking(data, ctx = dash.callback_context # Filter subset subset = None - if (ctx.triggered[0]['prop_id']=='reset_dropdown_button.n_clicks' or + if (ctx.triggered[0]['prop_id'] == 'apply_filter.n_clicks' or + ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks' or ('del_dropdown_button' in ctx.triggered[0]['prop_id'] and None not in nclicks_del)): selectedData = None From adda4e488347749827101952590c3354aaf36e9e Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Mon, 12 Jun 2023 17:00:10 +0200 Subject: [PATCH 07/12] Refacto dataset callback. --- shapash/webapp/smart_app.py | 77 ++--- shapash/webapp/utils/callbacks.py | 68 +++++ .../unit_tests/webapp/utils/test_callbacks.py | 282 ++++++++++++++++++ 3 files changed, 370 insertions(+), 57 deletions(-) create mode 100644 shapash/webapp/utils/callbacks.py create mode 100644 tests/unit_tests/webapp/utils/test_callbacks.py diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 15bb1dd9..029914f4 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -24,6 +24,10 @@ from shapash.webapp.utils.MyGraph import MyGraph from shapash.utils.utils import truncate_str from shapash.webapp.utils.explanations import Explanations +from shapash.webapp.utils.callbacks import ( + select_data_from_prediction_picking, + select_data_from_filters, +) def create_input_modal(id, label, tooltip): @@ -1446,14 +1450,7 @@ def update_datatable(selected_data, df = self.round_dataframe elif ((ctx.triggered[0]['prop_id'] == 'prediction_picking.selectedData') and (selected_data is not None) and (len(selected_data) > 1)): - row_ids = [] - # If some data have been selected in prediction picking graph - if selected_data is not None and len(selected_data) > 1: - for p in selected_data['points']: - row_ids.append(p['customdata']) - df = self.round_dataframe.loc[row_ids] - else: - df = self.round_dataframe + df = select_data_from_prediction_picking(self.round_dataframe, selected_data) # If click on reset button elif ctx.triggered[0]['prop_id'] == 'reset_dropdown_button.n_clicks': df = self.round_dataframe @@ -1465,55 +1462,21 @@ def update_datatable(selected_data, (selected_data is not None and len(selected_data) == 1 and len(selected_data['points'])!=0 and selected_data['points'][0]['curveNumber'] > 0) )): - # get list of ID - feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] - str_id = [id_str_modality[i]['index'] for i in range(len(id_str_modality))] - bool_id = [id_bool_modality[i]['index'] for i in range(len(id_bool_modality))] - lower_id = [id_lower_modality[i]['index'] for i in range(len(id_lower_modality))] - date_id = [id_date[i]['index'] for i in 
range(len(id_date))] - df = self.round_dataframe - # If there is some filters - if len(feature_id) > 0: - for i in range(len(feature_id)): - # String filter - if feature_id[i] in str_id: - position = np.where(np.array(str_id) == feature_id[i])[0][0] - if ((position is not None) & (val_str_modality[position] is not None)): - df = df[df[val_feature[i]].isin(val_str_modality[position])] - else: - df = df - # Boolean filter - elif feature_id[i] in bool_id: - position = np.where(np.array(bool_id) == feature_id[i])[0][0] - if ((position is not None) & (val_bool_modality[position] is not None)): - df = df[df[val_feature[i]] == val_bool_modality[position]] - else: - df = df - # Date filter - elif feature_id[i] in date_id: - position = np.where(np.array(date_id) == feature_id[i])[0][0] - if((position is not None) & - (start_date[position] < end_date[position])): - df = df[((df[val_feature[i]] >= start_date[position]) & - (df[val_feature[i]] <= end_date[position]))] - else: - df = df - # Numeric filter - elif feature_id[i] in lower_id: - position = np.where(np.array(lower_id) == feature_id[i])[0][0] - if((position is not None) & (val_lower_modality[position] is not None) & - (val_upper_modality[position] is not None)): - if (val_lower_modality[position] < val_upper_modality[position]): - df = df[(df[val_feature[i]] >= val_lower_modality[position]) & - (df[val_feature[i]] <= val_upper_modality[position])] - else: - df = df - else: - df = df - else: - df = df - else: - df = df + df = select_data_from_filters( + self.round_dataframe, + id_feature, + id_str_modality, + id_bool_modality, + id_lower_modality, + id_date, + val_feature, + val_str_modality, + val_bool_modality, + val_lower_modality, + val_upper_modality, + start_date, + end_date, + ) filtered_subset_info = f"Subset length: {len(df)} ({int(round(100*len(df)/self.explainer.x_init.shape[0]))}%)" if len(df)==0: filtered_subset_color="danger" diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py new file mode 100644 index 00000000..5444b3ae --- /dev/null +++ b/shapash/webapp/utils/callbacks.py @@ -0,0 +1,68 @@ +from typing import Optional, Tuple + +import numpy as np +import pandas as pd + + + +def select_data_from_prediction_picking(round_dataframe: pd.DataFrame, selected_data: dict) -> pd.DataFrame: + row_ids = [] + for p in selected_data['points']: + row_ids.append(p['customdata']) + df = round_dataframe.loc[row_ids] + + return df + + +def select_data_from_filters( + round_dataframe: pd.DataFrame, + id_feature: list, + id_str_modality: list, + id_bool_modality: list, + id_lower_modality: list, + id_date: list, + val_feature: list, + val_str_modality: list, + val_bool_modality: list, + val_lower_modality: list, + val_upper_modality: list, + start_date: list, + end_date: list, +) -> pd.DataFrame: + # get list of ID + feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] + str_id = [id_str_modality[i]['index'] for i in range(len(id_str_modality))] + bool_id = [id_bool_modality[i]['index'] for i in range(len(id_bool_modality))] + lower_id = [id_lower_modality[i]['index'] for i in range(len(id_lower_modality))] + date_id = [id_date[i]['index'] for i in range(len(id_date))] + df = round_dataframe + # If there is some filters + if len(feature_id) > 0: + for i in range(len(feature_id)): + # String filter + if feature_id[i] in str_id: + position = np.where(np.array(str_id) == feature_id[i])[0][0] + if ((position is not None) & (val_str_modality[position] is not None)): + df = 
df[df[val_feature[i]].isin(val_str_modality[position])] + # Boolean filter + elif feature_id[i] in bool_id: + position = np.where(np.array(bool_id) == feature_id[i])[0][0] + if ((position is not None) & (val_bool_modality[position] is not None)): + df = df[df[val_feature[i]] == val_bool_modality[position]] + # Date filter + elif feature_id[i] in date_id: + position = np.where(np.array(date_id) == feature_id[i])[0][0] + if((position is not None) & + (start_date[position] < end_date[position])): + df = df[((df[val_feature[i]] >= start_date[position]) & + (df[val_feature[i]] <= end_date[position]))] + # Numeric filter + elif feature_id[i] in lower_id: + position = np.where(np.array(lower_id) == feature_id[i])[0][0] + if((position is not None) & (val_lower_modality[position] is not None) & + (val_upper_modality[position] is not None)): + if (val_lower_modality[position] < val_upper_modality[position]): + df = df[(df[val_feature[i]] >= val_lower_modality[position]) & + (df[val_feature[i]] <= val_upper_modality[position])] + + return df \ No newline at end of file diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py new file mode 100644 index 00000000..066fc299 --- /dev/null +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -0,0 +1,282 @@ +import unittest +import numpy as np +import pandas as pd +from sklearn.tree import DecisionTreeClassifier +from shapash import SmartExplainer +from shapash.webapp.smart_app import SmartApp +from shapash.webapp.utils.callbacks import select_data_from_prediction_picking, select_data_from_filters + + +class TestCallbacks(unittest.TestCase): + + def __init__(self, *args, **kwargs): + data = { + 'column1': [1, 2, 3, 4, 5], + 'column2': ['a', 'b', 'c', 'd', 'e'], + 'column3': [1.1, 3.3, 2.2, 4.4, 5.5], + 'column4': [True, False, True, False, False], + 'column5': pd.date_range('2023-01-01', periods=5), + } + + df = pd.DataFrame(data) + self.df = df + + dataframe_x = df[['column1','column3']].copy() + y_target = pd.DataFrame(data=np.array([1, 2, 3, 4, 5]), columns=['pred']) + model = DecisionTreeClassifier().fit(dataframe_x, y_target) + features_dict = {'column3': 'Useless col'} + additional_data = df[['column2']].copy() + additional_features_dict = {'column2': 'Additional col'} + self.xpl = SmartExplainer(model=model, features_dict=features_dict) + self.xpl.compile( + x=dataframe_x, + y_pred=y_target, + y_target=y_target, + additional_data=additional_data, + additional_features_dict=additional_features_dict + ) + self.smart_app = SmartApp(self.xpl) + super(TestCallbacks, self).__init__(*args, **kwargs) + + def test_default_init_data(self): + expected_result = pd.DataFrame( + { + '_index_': [0, 1, 2, 3, 4], + '_predict_': [1, 2, 3, 4, 5], + '_target_': [1, 2, 3, 4, 5], + 'column1': [1, 2, 3, 4, 5], + 'column3': [1.1, 3.3, 2.2, 4.4, 5.5], + '_column2': ['a', 'b', 'c', 'd', 'e'], + }, + ) + self.smart_app.init_data() + pd.testing.assert_frame_equal(expected_result, self.smart_app.round_dataframe) + + def test_limited_rows_init_data(self): + self.smart_app.init_data(3) + assert len(self.smart_app.round_dataframe)==3 + + def test_select_data_from_prediction_picking(self): + selected_data = {"points": [{"customdata":0}, {"customdata":2}]} + expected_result = pd.DataFrame( + { + 'column1': [1, 3], + 'column2': ['a', 'c'], + 'column3': [1.1, 2.2], + 'column4': [True, True], + 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-03')], + }, + index=[0, 2] + ) + result = 
select_data_from_prediction_picking(self.df, selected_data) + pd.testing.assert_frame_equal(expected_result, result) + + def test_select_data_from_filters_string(self): + round_dataframe = self.df + id_feature = [{'type': 'var_dropdown', 'index': 1}] + id_str_modality = [{'type': 'dynamic-str', 'index': 1}] + id_bool_modality = [] + id_lower_modality = [] + id_date = [] + val_feature = ['column2'] + val_str_modality = [['a', 'c']] + val_bool_modality = [] + val_lower_modality = [] + val_upper_modality = [] + start_date = [] + end_date = [] + + expected_result = pd.DataFrame( + { + 'column1': [1, 3], + 'column2': ['a', 'c'], + 'column3': [1.1, 2.2], + 'column4': [True, True], + 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-03')], + }, + index=[0, 2] + ) + result = select_data_from_filters( + round_dataframe, + id_feature, + id_str_modality, + id_bool_modality, + id_lower_modality, + id_date, + val_feature, + val_str_modality, + val_bool_modality, + val_lower_modality, + val_upper_modality, + start_date, + end_date, + ) + pd.testing.assert_frame_equal(expected_result, result) + + def test_select_data_from_filters_bool(self): + round_dataframe = self.df + id_feature = [{'type': 'var_dropdown', 'index': 2}] + id_str_modality = [] + id_bool_modality = [{'type': 'dynamic-bool', 'index': 2}] + id_lower_modality = [] + id_date = [] + val_feature = ['column4'] + val_str_modality = [] + val_bool_modality = [True] + val_lower_modality = [] + val_upper_modality = [] + start_date = [] + end_date = [] + + expected_result = pd.DataFrame( + { + 'column1': [1, 3], + 'column2': ['a', 'c'], + 'column3': [1.1, 2.2], + 'column4': [True, True], + 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-03')], + }, + index=[0, 2] + ) + result = select_data_from_filters( + round_dataframe, + id_feature, + id_str_modality, + id_bool_modality, + id_lower_modality, + id_date, + val_feature, + val_str_modality, + val_bool_modality, + val_lower_modality, + val_upper_modality, + start_date, + end_date, + ) + pd.testing.assert_frame_equal(expected_result, result) + + def test_select_data_from_filters_date(self): + round_dataframe = self.df + id_feature = [{'type': 'var_dropdown', 'index': 1}] + id_str_modality = [] + id_bool_modality = [] + id_lower_modality = [] + id_date = [{'type': 'dynamic-date', 'index': 1}] + val_feature = ['column5'] + val_str_modality = [] + val_bool_modality = [] + val_lower_modality = [] + val_upper_modality = [] + start_date = [pd.Timestamp('2023-01-01')] + end_date = [pd.Timestamp('2023-01-03')] + + expected_result = pd.DataFrame( + { + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'], + 'column3': [1.1, 3.3, 2.2], + 'column4': [True, False, True], + 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-02'), pd.Timestamp('2023-01-03')], + }, + index=[0, 1, 2] + ) + result = select_data_from_filters( + round_dataframe, + id_feature, + id_str_modality, + id_bool_modality, + id_lower_modality, + id_date, + val_feature, + val_str_modality, + val_bool_modality, + val_lower_modality, + val_upper_modality, + start_date, + end_date, + ) + pd.testing.assert_frame_equal(expected_result, result) + + def test_select_data_from_filters_numeric(self): + round_dataframe = self.df + id_feature = [{'type': 'var_dropdown', 'index': 1}, {'type': 'var_dropdown', 'index': 2}] + id_str_modality = [] + id_bool_modality = [] + id_lower_modality = [{'type': 'lower', 'index': 1}, {'type': 'lower', 'index': 2}] + id_date = [] + val_feature = ['column1', 'column3'] + val_str_modality = [] 
+ val_bool_modality = [] + val_lower_modality = [0, 0] + val_upper_modality = [3, 3] + start_date = [] + end_date = [] + + expected_result = pd.DataFrame( + { + 'column1': [1, 3], + 'column2': ['a', 'c'], + 'column3': [1.1, 2.2], + 'column4': [True, True], + 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-03')], + }, + index=[0, 2] + ) + result = select_data_from_filters( + round_dataframe, + id_feature, + id_str_modality, + id_bool_modality, + id_lower_modality, + id_date, + val_feature, + val_str_modality, + val_bool_modality, + val_lower_modality, + val_upper_modality, + start_date, + end_date, + ) + pd.testing.assert_frame_equal(expected_result, result) + + def test_select_data_from_filters_multi_types(self): + round_dataframe = self.df + id_feature = [{'type': 'var_dropdown', 'index': 1}, {'type': 'var_dropdown', 'index': 2}] + id_str_modality = [{'type': 'dynamic-str', 'index': 2}] + id_bool_modality = [] + id_lower_modality = [{'type': 'lower', 'index': 1}] + id_date = [] + val_feature = ['column1', 'column2'] + val_str_modality = [['a', 'c', 'd', 'e']] + val_bool_modality = [] + val_lower_modality = [0] + val_upper_modality = [3] + start_date = [] + end_date = [] + + expected_result = pd.DataFrame( + { + 'column1': [1, 3], + 'column2': ['a', 'c'], + 'column3': [1.1, 2.2], + 'column4': [True, True], + 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-03')], + }, + index=[0, 2] + ) + result = select_data_from_filters( + round_dataframe, + id_feature, + id_str_modality, + id_bool_modality, + id_lower_modality, + id_date, + val_feature, + val_str_modality, + val_bool_modality, + val_lower_modality, + val_upper_modality, + start_date, + end_date, + ) + pd.testing.assert_frame_equal(expected_result, result) \ No newline at end of file From ac88a71925d03f0a145970ad145c339576bde512 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Mon, 26 Jun 2023 17:37:32 +0200 Subject: [PATCH 08/12] Refacto feature importance and selector. --- shapash/webapp/smart_app.py | 55 ++---- shapash/webapp/utils/callbacks.py | 178 +++++++++++++++++- .../unit_tests/webapp/utils/test_callbacks.py | 98 +++++++++- 3 files changed, 291 insertions(+), 40 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 029914f4..706fc017 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -27,6 +27,12 @@ from shapash.webapp.utils.callbacks import ( select_data_from_prediction_picking, select_data_from_filters, + get_feature_from_clicked_data, + get_feature_from_features_groups, + get_group_name, + get_indexes_from_datatable, + update_click_data_on_subset_changes, + get_figure_zoom ) @@ -1562,11 +1568,7 @@ def update_feature_importance(label, """ ctx = dash.callback_context # Zoom is False by Default. 
It becomes True if we click on it - click = 2 if click_zoom is None else click_zoom - if click % 2 == 0: - zoom_active = False - else: - zoom_active = True + zoom_active = get_figure_zoom(click_zoom) selection = None list_index = self.list_index if clickData is not None and (ctx.triggered[0]['prop_id'] in [ @@ -1575,35 +1577,26 @@ def update_feature_importance(label, 'dataset.data' ] or ('del_dropdown_button' in ctx.triggered[0]['prop_id'] and None not in nclicks_del)): - point = clickData['points'][0] - point['curveNumber'] = 0 - clickData = {'points':[point]} + clickData = update_click_data_on_subset_changes(clickData) selected_feature = self.explainer.inv_features_dict.get( - clickData['points'][0]['label'].replace('', '').replace('', '') + get_feature_from_clicked_data(clickData) ) if clickData else None if self.explainer.features_groups and bool_group: - list_sub_features = [f for group_features in self.explainer.features_groups.values() - for f in group_features] if ctx.triggered[0]['prop_id'] == 'card_global_feature_importance.n_clicks': # When we click twice on the same bar this will reset the graph if clickData_store == clickData: selected_feature = None - if selected_feature in list_sub_features: - for k, v in self.explainer.features_groups.items(): - if selected_feature in v: - selected_feature = k + selected_feature = get_feature_from_features_groups(selected_feature, self.explainer.features_groups) + elif ctx.triggered[0]['prop_id'] == 'ember_global_feature_importance.n_clicks': + selected_feature = get_feature_from_features_groups(selected_feature, self.explainer.features_groups) else: selected_feature = None + selection = get_indexes_from_datatable(data, list_index) - selection = [d['_index_'] for d in data] - if len(selection)==len(list_index) or len(selection)==0: - selection=None - - group_name = selected_feature if (self.explainer.features_groups is not None - and selected_feature in self.explainer.features_groups.keys()) else None + group_name = get_group_name(selected_feature, self.explainer.features_groups) figure = self.explainer.plot.features_importance( max_features=features, @@ -1666,25 +1659,19 @@ def update_feature_selector(feature, fs_figure: feature selector graph """ # Zoom is False by Default. It becomes True if we click on it - click = 2 if click_zoom is None else click_zoom - if click % 2 == 0: - zoom_active = False - else: - zoom_active = True # To check if zoom is activated + zoom_active = get_figure_zoom(click_zoom) subset = None list_index = self.list_index if feature is not None: - selected_feature = feature['points'][0]['label'].replace('', '').replace('', '') + selected_feature = get_feature_from_clicked_data(feature) else: selected_feature = self.selected_feature if feature is not None and feature['points'][0]['curveNumber'] == 0 and \ len(gfi_figure['data']) == 2: - subset = [d['_index_'] for d in data] + subset = get_indexes_from_datatable(data, list_index) else: subset = None - if subset is not None and (len(subset)==len(list_index) or len(subset)==0): - subset = None fs_figure = self.explainer.plot.contribution_plot( col=selected_feature, @@ -1883,11 +1870,7 @@ def update_detail_feature(threshold, detail feature graph """ # Zoom is False by Default. 
It becomes True if we click on it - click = 2 if click_zoom is None else click_zoom - if click % 2 == 0: - zoom_active = False - else: - zoom_active = True + zoom_active = get_figure_zoom(click_zoom) ctx = dash.callback_context selected = None if ctx.triggered[0]['prop_id'] == 'feature_selector.clickData': @@ -2174,7 +2157,7 @@ def update_prediction_picking(data, if selectedData is not None and len(selectedData['points'])>0: raise PreventUpdate else: - subset = [d['_index_'] for d in data] + subset = get_indexes_from_datatable(data) figure = self.explainer.plot.scatter_plot_prediction( selection=subset, diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index 5444b3ae..fed145cc 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional import numpy as np import pandas as pd @@ -6,6 +6,20 @@ def select_data_from_prediction_picking(round_dataframe: pd.DataFrame, selected_data: dict) -> pd.DataFrame: + """Create a subset dataframe from the prediction picking graph selection. + + Parameters + ---------- + round_dataframe : pd.DataFrame + Data to sample + selected_data : dict + Selected sample in the prediction picking graph + + Returns + ------- + pd.DataFrame + Subset dataframe + """ row_ids = [] for p in selected_data['points']: row_ids.append(p['customdata']) @@ -29,6 +43,42 @@ def select_data_from_filters( start_date: list, end_date: list, ) -> pd.DataFrame: + """Create a subset dataframe from filters. + + Parameters + ---------- + round_dataframe : pd.DataFrame + Data to sample + id_feature : list + features ids + id_str_modality : list + string features ids + id_bool_modality : list + boolean features ids + id_lower_modality : list + numeric features ids + id_date : list + date features ids + val_feature : list + features names + val_str_modality : list + string modalities selected + val_bool_modality : list + boolean modalities selected + val_lower_modality : list + lower values of numeric filter + val_upper_modality : list + upper values of numeric filter + start_date : list + start dates selected + end_date : list + end dates selected + + Returns + ------- + pd.DataFrame + Subset dataframe + """ # get list of ID feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] str_id = [id_str_modality[i]['index'] for i in range(len(id_str_modality))] @@ -65,4 +115,128 @@ def select_data_from_filters( df = df[(df[val_feature[i]] >= val_lower_modality[position]) & (df[val_feature[i]] <= val_upper_modality[position])] - return df \ No newline at end of file + return df + + +def get_feature_from_clicked_data(click_data: dict) -> str: + """Get the feature name from the feature importance graph click data. + + Parameters + ---------- + click_data : dict + Feature importance graph click data + + Returns + ------- + str + Selected feature + """ + selected_feature = click_data['points'][0]['label'].replace('', '').replace('', '') + return selected_feature + + +def get_feature_from_features_groups(selected_feature: Optional[str], features_groups: dict) -> Optional[str]: + """Get the group feature name of the selected feature. + + Parameters + ---------- + selected_feature : Optional[str] + Selected feature + features_groups : dict + Groups names and corresponding list of features + + Returns + ------- + Optional[str] + Group feature name or selected feature if not in a group. 
+ """ + list_sub_features = [f for group_features in features_groups.values() + for f in group_features] + if selected_feature in list_sub_features: + for k, v in features_groups.items(): + if selected_feature in v: + selected_feature = k + return selected_feature + + +def get_group_name(selected_feature: Optional[str], features_groups: Optional[dict]) -> Optional[str]: + """Get the group feature name if the selected feature is one of the groups. + + Parameters + ---------- + selected_feature : Optional[str] + Selected feature + features_groups : Optional[dict] + Groups names and corresponding list of features + + Returns + ------- + Optional[str] + Group feature name + """ + group_name = selected_feature if ( + features_groups is not None and selected_feature in features_groups.keys() + ) else None + return group_name + + +def get_indexes_from_datatable(data: list, list_index: Optional[list] = None) -> Optional[list]: + """Get the indexes of the data. If list_index is given and is the same length than + the indexes, there is no subset selected. + + Parameters + ---------- + data : list + Data from the table + list_index : Optional[list], optional + Default index list to compare the subset with, by default None + + Returns + ------- + Optional[list] + Indexes of the data + """ + indexes = [d['_index_'] for d in data] + if list_index is not None and (len(indexes)==len(list_index) or len(indexes)==0): + indexes = None + return indexes + + +def update_click_data_on_subset_changes(click_data: dict) -> dict: + """Update click data on subset changes to always correspond to the feature selector graph. + + Parameters + ---------- + click_data : dict + Feature importance click data + + Returns + ------- + dict + Updated feature importance click data + """ + point = click_data['points'][0] + point['curveNumber'] = 0 + click_data = {'points':[point]} + return click_data + + +def get_figure_zoom(click_zoom: int) -> bool : + """Get figure zoom from n_clicks + + Parameters + ---------- + click_zoom : int + Number of clicks on zoom button + + Returns + ------- + bool + zoom active or not + """ + click = 2 if click_zoom is None else click_zoom + if click % 2 == 0: + zoom_active = False + else: + zoom_active = True + return zoom_active diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index 066fc299..700398a4 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -4,7 +4,16 @@ from sklearn.tree import DecisionTreeClassifier from shapash import SmartExplainer from shapash.webapp.smart_app import SmartApp -from shapash.webapp.utils.callbacks import select_data_from_prediction_picking, select_data_from_filters +from shapash.webapp.utils.callbacks import ( + select_data_from_prediction_picking, + select_data_from_filters, + get_feature_from_clicked_data, + get_feature_from_features_groups, + get_group_name, + get_indexes_from_datatable, + update_click_data_on_subset_changes, + get_figure_zoom +) class TestCallbacks(unittest.TestCase): @@ -36,6 +45,24 @@ def __init__(self, *args, **kwargs): additional_features_dict=additional_features_dict ) self.smart_app = SmartApp(self.xpl) + + self.click_data = { + 'points': [ + { + 'curveNumber': 0, + 'pointNumber': 3, + 'pointIndex': 3, + 'x': 0.4649, + 'y': 'Sex', + 'label': 'Sex', + 'value': 0.4649, + 'customdata': 'Sex', + 'marker.color': 'rgba(244, 192, 0, 1.0)', + 'bbox': {'x0': 717.3, 'x1': 717.3, 'y0': 82.97, 'y1': 130.78} + } + ] + } + 
super(TestCallbacks, self).__init__(*args, **kwargs) def test_default_init_data(self): @@ -279,4 +306,71 @@ def test_select_data_from_filters_multi_types(self): start_date, end_date, ) - pd.testing.assert_frame_equal(expected_result, result) \ No newline at end of file + pd.testing.assert_frame_equal(expected_result, result) + + def test_get_feature_from_clicked_data(self): + feature = get_feature_from_clicked_data(self.click_data) + assert feature == "Sex" + + def test_get_feature_from_features_groups(self): + features_groups = {"A": ["column1", "column3"]} + feature = get_feature_from_features_groups("column3", features_groups) + assert feature == "A" + + feature = get_feature_from_features_groups("A", features_groups) + assert feature == "A" + + feature = get_feature_from_features_groups("column2", features_groups) + assert feature == "column2" + + def test_get_group_name(self): + features_groups = {"A": ["column1", "column3"]} + feature = get_group_name("A", features_groups) + assert feature == "A" + + feature = get_group_name("column3", features_groups) + assert feature == None + + def test_get_indexes_from_datatable(self): + data = self.smart_app.components['table']['dataset'].data + subset = get_indexes_from_datatable(data) + assert subset == [0, 1, 2, 3, 4] + + def test_get_indexes_from_datatable_no_subset(self): + data = self.smart_app.components['table']['dataset'].data + subset = get_indexes_from_datatable(data, [0, 1, 2, 3, 4]) + assert subset == None + + def test_get_indexes_from_datatable_empty(self): + subset = get_indexes_from_datatable([], [0, 1, 2, 3, 4]) + assert subset == None + + def test_update_click_data_on_subset_changes(self): + click_data = { + 'points': [ + { + 'curveNumber': 1, + 'pointNumber': 3, + 'pointIndex': 3, + 'x': 0.4649, + 'y': 'Sex', + 'label': 'Sex', + 'value': 0.4649, + 'customdata': 'Sex', + 'marker.color': 'rgba(244, 192, 0, 1.0)', + 'bbox': {'x0': 717.3, 'x1': 717.3, 'y0': 82.97, 'y1': 130.78} + } + ] + } + click_data = update_click_data_on_subset_changes(click_data) + assert click_data == self.click_data + + def test_get_figure_zoom(self): + zoom_active = get_figure_zoom(None) + assert zoom_active == False + + zoom_active = get_figure_zoom(1) + assert zoom_active == True + + zoom_active = get_figure_zoom(4) + assert zoom_active == False From b1bef6e7194fd65342479f7c7ca392f2173d8cab Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Wed, 28 Jun 2023 14:39:28 +0200 Subject: [PATCH 09/12] Refacto index, detail_feature, id card. 
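
This patch moves the index selection, the detail_feature callback and the identity card assembly out of the Dash callbacks and into plain helpers in shapash/webapp/utils/callbacks.py. As a rough guide for review, the id card is now built by chaining the new helpers. A minimal sketch, assuming `xpl` is an already compiled SmartExplainer such as the one built in the unit-test fixture (row number 3 and label 0 are arbitrary illustration values, not part of the API):

    from shapash.webapp.smart_app import SmartApp
    from shapash.webapp.utils.callbacks import (
        get_id_card_features,
        get_id_card_contrib,
        create_id_card_data,
        create_id_card_layout,
    )

    smart_app = SmartApp(xpl)
    data = smart_app.components['table']['dataset'].data
    special_cols = ['_index_', '_predict_', '_target_']

    # Domain names for regular and additional features
    features_dict = dict(xpl.features_dict, **xpl.additional_features_dict)

    selected_row = get_id_card_features(data, 3, special_cols, features_dict)
    selected_contrib = get_id_card_contrib(xpl.data, 3, xpl.features_dict, xpl.columns_dict, 0)
    selected_data = create_id_card_data(
        selected_row, selected_contrib, "feature_name", True,
        special_cols, xpl.additional_features_dict,
    )
    children = create_id_card_layout(selected_data, xpl.additional_features_dict)

Keeping these as pure functions makes them testable without spinning up the Dash app, which is what the new unit tests do.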
--- shapash/webapp/smart_app.py | 152 ++++------- shapash/webapp/utils/callbacks.py | 237 ++++++++++++++++-- .../unit_tests/webapp/utils/test_callbacks.py | 106 +++++++- 3 files changed, 375 insertions(+), 120 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 706fc017..01e6c753 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -32,7 +32,13 @@ get_group_name, get_indexes_from_datatable, update_click_data_on_subset_changes, - get_figure_zoom + get_figure_zoom, + get_feature_contributions_sign_to_show, + update_features_to_display, + get_id_card_features, + get_id_card_contrib, + create_id_card_data, + create_id_card_layout ) @@ -1736,22 +1742,15 @@ def update_index_id(click_data, """ ctx = dash.callback_context selected = None - if ctx.triggered[0]['prop_id'] != 'dataset.data': - if ctx.triggered[0]['prop_id'] == 'feature_selector.clickData': - selected = click_data['points'][0]['customdata'][1] - elif ctx.triggered[0]['prop_id'] == 'prediction_picking.clickData': - selected = prediction_picking['points'][0]['customdata'] - elif ctx.triggered[0]['prop_id'] == 'dataset.active_cell': - if cell is not None: - selected = data[cell['row']]['_index_'] - else: - # Get actual value in field to refresh the selected value - selected = current_index_id - elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & - (None in nclicks_del)): - selected = current_index_id - else: - raise PreventUpdate + if ctx.triggered[0]['prop_id'] == 'feature_selector.clickData': + selected = click_data['points'][0]['customdata'][1] + elif ctx.triggered[0]['prop_id'] == 'prediction_picking.clickData': + selected = prediction_picking['points'][0]['customdata'] + elif ctx.triggered[0]['prop_id'] == 'dataset.active_cell': + selected = data[cell['row']]['_index_'] + elif (('del_dropdown_button' in ctx.triggered[0]['prop_id']) & + (None in nclicks_del)): + selected = current_index_id return selected, True @app.callback( @@ -1792,25 +1791,11 @@ def update_max_contrib_id(is_open, if is_open: raise PreventUpdate else: - max = min(features, len(self.explainer.x_init.columns)) - if max // 5 == max / 5: - nb_marks = min(int(max // 5), 10) - elif max // 4 == max / 4: - nb_marks = min(int(max // 4), 10) - elif max // 3 == max / 3: - nb_marks = min(int(max // 3), 10) - elif max // 7 == max / 7: - nb_marks = min(int(max // 6), 10) - else: - nb_marks = 2 - marks = {f'{round(max * feat / nb_marks)}': f'{round(max * feat / nb_marks)}' - for feat in range(1, nb_marks + 1)} - marks['1'] = '1' - if max < value: - value = max - else: - value = no_update - + value, max, marks = update_features_to_display( + features, + len(self.explainer.x_init.columns), + value + ) return value, max, marks @app.callback( @@ -1822,9 +1807,6 @@ def update_max_contrib_id(is_open, Input('check_id_negative', 'value'), Input('masked_contrib_id', 'value'), Input('select_label', 'value'), - Input('dataset', 'active_cell'), - Input('feature_selector', 'clickData'), - Input('prediction_picking', 'clickData'), Input("validation", "n_clicks"), Input('bool_groups', 'on'), Input('ember_detail_feature', 'n_clicks'), @@ -1840,9 +1822,6 @@ def update_detail_feature(threshold, negative, masked, label, - cell, - click_data, - prediction_picking, validation_click, bool_group, click_zoom, @@ -1871,29 +1850,11 @@ def update_detail_feature(threshold, """ # Zoom is False by Default. 
It becomes True if we click on it zoom_active = get_figure_zoom(click_zoom) - ctx = dash.callback_context - selected = None - if ctx.triggered[0]['prop_id'] == 'feature_selector.clickData': - selected = click_data['points'][0]['customdata'][1] - elif ctx.triggered[0]['prop_id'] == 'prediction_picking.clickData': - selected = prediction_picking['points'][0]['customdata'] - elif ctx.triggered[0]['prop_id'] in ['threshold_id.value', 'validation.n_clicks']: - selected = index - elif ctx.triggered[0]['prop_id'] == 'dataset.active_cell': - if cell: - selected = data[cell['row']]['_index_'] - else: - zoom_active = zoom_active - # raise PreventUpdate - else: - selected = index + selected = index if check_row(data, selected) is None: selected = None threshold = threshold if threshold != 0 else None - if positive == [1]: - sign = (None if negative == [1] else True) - else: - sign = (False if negative == [1] else None) + sign = get_feature_contributions_sign_to_show(positive, negative) self.explainer.filter(threshold=threshold, features_to_hide=masked, positive=sign, @@ -1970,54 +1931,43 @@ def update_id_card(n_submit, label, sort_by, order, data, index): selected = check_row(data, index) title_contrib = "Contribution" if n_submit and selected is not None: - selected_row = pd.DataFrame([data[selected]], index=["feature_value"]).T - selected_row["feature_name"] = selected_row.index.map( - lambda x: x if x in self.special_cols else self.features_dict[x] - ) + selected_row = get_id_card_features(data, selected, self.special_cols, self.features_dict) if self.explainer._case == 'classification': if label is None: label = -1 label_num, _, label_value = self.explainer.check_label_name(label) - contrib = self.explainer.data['contrib_sorted'][label_num].loc[index, :].values - var_dict = self.explainer.data['var_dict'][label_num].loc[index, :].values + selected_contrib = get_id_card_contrib( + self.explainer.data, + index, + self.explainer.features_dict, + self.explainer.columns_dict, + label_num + ) proba = self.explainer.plot.local_pred(index, label_num) title_contrib = f"Contribution: {label_value} ({proba.round(2):.2f})" - _, _, predicted_label_value = self.explainer.check_label_name(selected_row.loc["_predict_", "feature_value"]) + _, _, predicted_label_value = self.explainer.check_label_name( + selected_row.loc["_predict_", "feature_value"] + ) selected_row.loc["_predict_", "feature_value"] = predicted_label_value else: - contrib = self.explainer.data['contrib_sorted'].loc[index, :].values - var_dict = self.explainer.data['var_dict'].loc[index, :].values - var_dict = [self.explainer.features_dict[self.explainer.columns_dict[x]] for x in var_dict] - selected_contrib = pd.DataFrame([var_dict, contrib], index=["feature_name", "feature_contrib"]).T - selected_contrib["feature_contrib"] = selected_contrib["feature_contrib"].apply(lambda x: round(x, 4)) - selected_data = selected_row.merge(selected_contrib, how="left", on="feature_name") - selected_data.index = selected_row.index - selected_data = pd.concat([ - selected_data.loc[self.special_cols], - selected_data.drop(index=self.special_cols+list(self.explainer.additional_features_dict.keys())).sort_values(sort_by, ascending=order), - selected_data.loc[list(self.explainer.additional_features_dict.keys())].sort_values(sort_by, ascending=order) - ]) - children = [] - for _, row in selected_data.iterrows(): - label_style = { - 'fontWeight': 'bold', - 'font-style': 'italic' - } if row["feature_name"] in self.explainer.additional_features_dict.values() else 
{'fontWeight': 'bold'} - children.append( - dbc.Row([ - dbc.Col(dbc.Label(row["feature_name"]), width=3, style=label_style), - dbc.Col(dbc.Label(row["feature_value"]), width=5, className="id_card_solid"), - dbc.Col(width=1), - dbc.Col( - dbc.Row( - dbc.Label(format(row["feature_contrib"], '.4f'), width="auto", style={"padding-top":0}), - justify="end" - ), - width=2, - className="id_card_solid", - ) if row["feature_contrib"]==row["feature_contrib"] else None, - ]) + selected_contrib = get_id_card_contrib( + self.explainer.data, + index, + self.explainer.features_dict, + self.explainer.columns_dict, ) + + selected_data = create_id_card_data( + selected_row, + selected_contrib, + sort_by, + order, + self.special_cols, + self.explainer.additional_features_dict + ) + + children = create_id_card_layout(selected_data, self.explainer.additional_features_dict) + return {"display":"flex", "margin-left":"auto", "margin-right":0}, children, title_contrib else: return {"display":"none"}, [], title_contrib diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index fed145cc..000e3715 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import Optional, Tuple import numpy as np import pandas as pd - +import dash_bootstrap_components as dbc def select_data_from_prediction_picking(round_dataframe: pd.DataFrame, selected_data: dict) -> pd.DataFrame: @@ -29,19 +29,19 @@ def select_data_from_prediction_picking(round_dataframe: pd.DataFrame, selected_ def select_data_from_filters( - round_dataframe: pd.DataFrame, - id_feature: list, - id_str_modality: list, - id_bool_modality: list, - id_lower_modality: list, - id_date: list, - val_feature: list, - val_str_modality: list, - val_bool_modality: list, - val_lower_modality: list, - val_upper_modality: list, - start_date: list, - end_date: list, + round_dataframe: pd.DataFrame, + id_feature: list, + id_str_modality: list, + id_bool_modality: list, + id_lower_modality: list, + id_date: list, + val_feature: list, + val_str_modality: list, + val_bool_modality: list, + val_lower_modality: list, + val_upper_modality: list, + start_date: list, + end_date: list, ) -> pd.DataFrame: """Create a subset dataframe from filters. @@ -240,3 +240,210 @@ def get_figure_zoom(click_zoom: int) -> bool : else: zoom_active = True return zoom_active + + +def get_feature_contributions_sign_to_show(positive: list, negative: list) -> Optional[bool]: + """Get the feature contributions sign to show on plot. + + Parameters + ---------- + positive : list + Click on positive contributions + negative : list + Click on negative contributions + + Returns + ------- + Optional[bool] + Sign to show on plot + """ + if positive == [1]: + sign = (None if negative == [1] else True) + else: + sign = (False if negative == [1] else None) + return sign + + +def update_features_to_display(features: int, nb_columns: int, value: int) -> Tuple[int, int, dict]: + """Update features to display slider. 
+ + Parameters + ---------- + features : int + Number of features to plot from the settings + nb_columns : int + Number of columns in the data + value : int + Number of columns to plot + + Returns + ------- + Tuple[int, int, dict] + Number of columns to plot, Number max of columns to plot, Marks in the slider + """ + max = min(features, nb_columns) + if max % 5 == 0: + nb_marks = min(int(max // 5), 10) + elif max % 4 == 0: + nb_marks = min(int(max // 4), 10) + elif max % 3 == 0: + nb_marks = min(int(max // 3), 10) + elif max % 7 == 0: + nb_marks = min(int(max // 7), 10) + else: + nb_marks = 2 + marks = {f'{round(max * feat / nb_marks)}': f'{round(max * feat / nb_marks)}' + for feat in range(1, nb_marks + 1)} + marks['1'] = '1' + if max < value: + value = max + + return value, max, marks + + +def get_id_card_features(data: list, selected: int, special_cols: list, features_dict: dict) -> pd.DataFrame: + """Get the features of the selected index for the identity card. + + Parameters + ---------- + data : list + Data from the table + selected : int + Row number of the selected index + special_cols : list + Sepcial columns about the index, the prediction... + features_dict : dict + Dictionary mapping technical feature names to domain names + + Returns + ------- + pd.DataFrame + Dataframe of the features + """ + selected_row = pd.DataFrame([data[selected]], index=["feature_value"]).T + selected_row["feature_name"] = selected_row.index.map( + lambda x: x if x in special_cols else features_dict[x] + ) + return selected_row + + +def get_id_card_contrib( + data: dict, + index: int, + features_dict: dict, + columns_dict: dict, + label_num: int = None +) -> pd.DataFrame: + """Get the contributions of the selected index for the identity card. + + Parameters + ---------- + data : dict + Data from the smart explainer + index : int + Index selected + features_dict : dict + Dictionary mapping technical feature names to domain names + columns_dict : dict + Dictionary mapping integer column number to technical feature names + label_num : int, optional + Label num, by default None + + Returns + ------- + pd.DataFrame + Dataframe of the contributions + """ + if label_num is not None: + contrib = data['contrib_sorted'][label_num].loc[index, :].values + var_dict = data['var_dict'][label_num].loc[index, :].values + else: + contrib = data['contrib_sorted'].loc[index, :].values + var_dict = data['var_dict'].loc[index, :].values + + var_dict = [features_dict[columns_dict[x]] for x in var_dict] + selected_contrib = pd.DataFrame([var_dict, contrib], index=["feature_name", "feature_contrib"]).T + selected_contrib["feature_contrib"] = selected_contrib["feature_contrib"].apply(lambda x: round(x, 4)) + + return selected_contrib + + +def create_id_card_data( + selected_row: pd.DataFrame, + selected_contrib: pd.DataFrame, + sort_by: str, + order: bool, + special_cols: list, + additional_features_dict: dict +) -> pd.DataFrame: + """Merge and sort features and contributions dataframes for the identity card. + + Parameters + ---------- + selected_row : pd.DataFrame + Dataframe of the features + selected_contrib : pd.DataFrame + Dataframe of the contributions + sort_by : str + Column to sort by + order : bool + Ascending or descending order + special_cols : list + Sepcial columns about the index, the prediction... 
+ additional_features_dict : dict + Dictionary mapping technical feature names to domain names for additional data + + Returns + ------- + pd.DataFrame + Dataframe of the data for the identity card + """ + selected_data = selected_row.merge(selected_contrib, how="left", on="feature_name") + selected_data.index = selected_row.index + selected_data = pd.concat([ + selected_data.loc[special_cols], + selected_data.drop(index=special_cols+list(additional_features_dict.keys())).sort_values(sort_by, ascending=order), + selected_data.loc[list(additional_features_dict.keys())].sort_values(sort_by, ascending=order) + ]) + + return selected_data + + +def create_id_card_layout(selected_data: pd.DataFrame, additional_features_dict: dict) -> list: + """Create the layout of the identity card + + Parameters + ---------- + selected_data : pd.DataFrame + Dataframe of the data for the identity card + additional_features_dict : dict + Dictionary mapping technical feature names to domain names for additional data + + Returns + ------- + list + Layout of the identity card + """ + children = [] + for _, row in selected_data.iterrows(): + label_style = { + 'fontWeight': 'bold', + 'font-style': 'italic' + } if row["feature_name"] in additional_features_dict.values() else {'fontWeight': 'bold'} + children.append( + dbc.Row([ + dbc.Col(dbc.Label(row["feature_name"]), width=3, style=label_style), + dbc.Col(dbc.Label(row["feature_value"]), width=5, className="id_card_solid"), + dbc.Col(width=1), + dbc.Col( + dbc.Row( + dbc.Label(format(row["feature_contrib"], '.4f'), width="auto", style={"padding-top":0}), + justify="end" + ), + width=2, + className="id_card_solid", + ) if row["feature_contrib"]==row["feature_contrib"] else None, + ]) + ) + + return children \ No newline at end of file diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index 700398a4..f3123e03 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -1,4 +1,5 @@ import unittest +import copy import numpy as np import pandas as pd from sklearn.tree import DecisionTreeClassifier @@ -12,7 +13,13 @@ get_group_name, get_indexes_from_datatable, update_click_data_on_subset_changes, - get_figure_zoom + get_figure_zoom, + get_feature_contributions_sign_to_show, + update_features_to_display, + get_id_card_features, + get_id_card_contrib, + create_id_card_data, + create_id_card_layout, ) @@ -31,7 +38,7 @@ def __init__(self, *args, **kwargs): self.df = df dataframe_x = df[['column1','column3']].copy() - y_target = pd.DataFrame(data=np.array([1, 2, 3, 4, 5]), columns=['pred']) + y_target = pd.DataFrame(data=np.array([0, 0, 0, 1, 1]), columns=['pred']) model = DecisionTreeClassifier().fit(dataframe_x, y_target) features_dict = {'column3': 'Useless col'} additional_data = df[['column2']].copy() @@ -62,6 +69,7 @@ def __init__(self, *args, **kwargs): } ] } + self.special_cols = ['_index_', '_predict_', '_target_'] super(TestCallbacks, self).__init__(*args, **kwargs) @@ -69,8 +77,8 @@ def test_default_init_data(self): expected_result = pd.DataFrame( { '_index_': [0, 1, 2, 3, 4], - '_predict_': [1, 2, 3, 4, 5], - '_target_': [1, 2, 3, 4, 5], + '_predict_': [0, 0, 0, 1, 1], + '_target_': [0, 0, 0, 1, 1], 'column1': [1, 2, 3, 4, 5], 'column3': [1.1, 3.3, 2.2, 4.4, 5.5], '_column2': ['a', 'b', 'c', 'd', 'e'], @@ -374,3 +382,93 @@ def test_get_figure_zoom(self): zoom_active = get_figure_zoom(4) assert zoom_active == False + + def 
test_get_feature_contributions_sign_to_show(self): + sign = get_feature_contributions_sign_to_show([1], [1]) + assert sign == None + + sign = get_feature_contributions_sign_to_show([1], []) + assert sign == True + + sign = get_feature_contributions_sign_to_show([], []) + assert sign == None + + sign = get_feature_contributions_sign_to_show([], [1]) + assert sign == False + + def test_update_features_to_display(self): + value, max, marks = update_features_to_display(20, 40, 22) + assert value==20 + assert max==20 + assert marks=={'1': '1', '5': '5', '10': '10', '15': '15', '20': '20'} + + value, max, marks = update_features_to_display(7, 40, 6) + assert value==6 + assert max==7 + assert marks=={'1': '1', '7': '7'} + + def test_get_id_card_features(self): + data = self.smart_app.components['table']['dataset'].data + features_dict = copy.deepcopy(self.xpl.features_dict) + features_dict.update(self.xpl.additional_features_dict) + selected_row = get_id_card_features(data, 3, self.special_cols, features_dict) + expected_result = pd.DataFrame( + { + 'feature_value': [3, 1, 1, 4, 4.4, 'd'], + 'feature_name': ['_index_', '_predict_', '_target_', 'column1', 'Useless col', '_Additional col'], + }, + index = ['_index_', '_predict_', '_target_', 'column1', 'column3', '_column2'] + ) + pd.testing.assert_frame_equal(selected_row, expected_result) + + def test_get_id_card_contrib(self): + data = self.xpl.data + selected_contrib = get_id_card_contrib(data, 3, self.xpl.features_dict, self.xpl.columns_dict, 0) + assert set(selected_contrib['feature_name']) == {'Useless col', 'column1'} + assert selected_contrib.columns.tolist() == ['feature_name', 'feature_contrib'] + + def test_create_id_card_data(self): + selected_row = pd.DataFrame( + { + 'feature_value': [3, 1, 1, 4, 4.4, 'd'], + 'feature_name': ['_index_', '_predict_', '_target_', 'column1', 'Useless col', '_Additional col'], + }, + index = ['_index_', '_predict_', '_target_', 'column1', 'column3', '_column2'] + ) + + selected_contrib = pd.DataFrame( + { + 'feature_name': ['column1', 'Useless col'], + 'feature_contrib': [-0.6, 0], + } + ) + + selected_data = create_id_card_data( + selected_row, + selected_contrib, + 'feature_name', + True, + self.special_cols, + self.xpl.additional_features_dict + ) + expected_result = pd.DataFrame( + { + 'feature_value': [3, 1, 1, 4.4, 4, 'd'], + 'feature_name': ['_index_', '_predict_', '_target_', 'Useless col', 'column1', '_Additional col'], + 'feature_contrib': [np.nan, np.nan, np.nan, 0.0, -0.6, np.nan] + }, + index = ['_index_', '_predict_', '_target_', 'column3', 'column1', '_column2'] + ) + pd.testing.assert_frame_equal(selected_data, expected_result) + + def test_create_id_card_layout(self): + selected_data = pd.DataFrame( + { + 'feature_value': [3, 1, 1, 4.4, 4, 'd'], + 'feature_name': ['_index_', '_predict_', '_target_', 'Useless col', 'column1', '_Additional col'], + 'feature_contrib': [np.nan, np.nan, np.nan, 0.0, -0.6, np.nan] + }, + index = ['_index_', '_predict_', '_target_', 'column3', 'column1', '_column2'] + ) + children = create_id_card_layout(selected_data, self.xpl.additional_features_dict) + assert len(children)==6 \ No newline at end of file From 5da90175f68d3c621d14d9f4efe38ecaa50160b5 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Fri, 30 Jun 2023 17:46:17 +0200 Subject: [PATCH 10/12] Refacto filter callbacks and others. 
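
This patch extracts the subset-filter layout from the callbacks: get_feature_filter_options builds the sorted dropdown options, create_dropdown_feature_filter builds one filter block, and create_filter_modalities_selection returns the widget matching the selected column type (radio items, dropdown, date range or numeric bounds). A minimal sketch of the chain, assuming `smart_app` is an instantiated SmartApp and 'column1' is just an example column from the test fixture:

    from shapash.webapp.utils.callbacks import (
        get_feature_filter_options,
        create_dropdown_feature_filter,
        create_filter_modalities_selection,
    )

    special_cols = ['_index_', '_predict_', '_target_']

    # Options of the "Variable" dropdown, sorted by domain name
    options = get_feature_filter_options(
        smart_app.dataframe, smart_app.features_dict, special_cols
    )

    # First filter block added by the "add filter" button (n_clicks is None on first use)
    filter_block = create_dropdown_feature_filter(None, options)

    # Modality widget generated once a variable has been selected in that block
    modalities = create_filter_modalities_selection(
        'column1', {'type': 'var_dropdown', 'index': 0}, smart_app.round_dataframe
    )

The widget type is still decided from the dtype of the first value of the column; only the location of the code changes.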
--- shapash/webapp/smart_app.py | 129 +------------ shapash/webapp/utils/callbacks.py | 182 +++++++++++++++++- .../unit_tests/webapp/utils/test_callbacks.py | 18 +- 3 files changed, 207 insertions(+), 122 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 01e6c753..dcca5861 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -38,7 +38,10 @@ get_id_card_features, get_id_card_contrib, create_id_card_data, - create_id_card_layout + create_id_card_layout, + get_feature_filter_options, + create_dropdown_feature_filter, + create_filter_modalities_selection ) @@ -2275,66 +2278,11 @@ def layout_filter(n_clicks_add, raise dash.exceptions.PreventUpdate button_id = ctx.triggered[0]['prop_id'].split('.')[0] - # We use domain name for feature name - dict_name = [self.features_dict[i] - for i in self.dataframe.drop(self.special_cols, axis=1).columns] - dict_id = [i for i in self.dataframe.drop(self.special_cols, axis=1).columns] - # Create dataframe to sort it by feature_name - df_feature_name = pd.DataFrame({'feature_name': dict_name, - 'feature_id': dict_id}) - df_feature_name = df_feature_name.sort_values( - by='feature_name').reset_index(drop=True) - # Options are sorted by feature_name - options = [{"label": i, "value": i} for i in self.special_cols] + \ - [{"label": df_feature_name.loc[i, 'feature_name'], - "value": df_feature_name.loc[i, 'feature_id']} - for i in range(len(df_feature_name))] + options = get_feature_filter_options(self.dataframe, self.features_dict, self.special_cols) # Creation of a new graph if button_id == 'add_dropdown_button': - # ID index definition - if n_clicks_add is None: - index_id = 0 - else: - index_id = n_clicks_add - # Appending a dropdown block to 'dropdowns_container'children - subset_filter = html.Div( - id={'type': 'bloc_div', - 'index': index_id}, - children=[ - html.Div([ - html.Br(), - # div which will contains label - html.Div( - id={'type': 'dynamic-output-label', - 'index': index_id}, - ) - ]), - html.Div([ - # div with dopdown button to select feature to filter - html.Div(dcc.Dropdown( - id={'type': 'var_dropdown', - 'index': index_id}, - options=options, - placeholder="Variable" - ), style={"width": "30%"}), - # div which will contains modalities - html.Div( - id={'type': 'dynamic-output', - 'index': index_id}, - style={"width": "50%"} - ), - # Button to delete bloc - dbc.Button( - id={'type': 'del_dropdown_button', - 'index': index_id}, - children='X', - color='warning', - size='sm' - ) - ], style={'display': 'flex'}) - ] - ) + subset_filter = create_dropdown_feature_filter(n_clicks_add, options) return currents_filters + [subset_filter] # Removal of all existing filters elif button_id == 'reset_dropdown_button': @@ -2455,73 +2403,14 @@ def display_output(value, # Context and init handling (no action) ctx = dash.callback_context if not ctx.triggered: - raise dash.exceptions.PreventUpdate + raise PreventUpdate # No update last modalities values if we click on add button if ctx.triggered[0]['prop_id'] == 'add_dropdown_button.n_clicks': - raise dash.exceptions.PreventUpdate + raise PreventUpdate # Creation on modalities dropdown button else: if value is not None: - if type(self.round_dataframe[value].iloc[0]) == bool: - new_element = html.Div(dcc.RadioItems( - [{'label': val, 'value': val} for - val in self.round_dataframe[value].unique()], - id={'type': 'dynamic-bool', - 'index': id['index']}, - value=self.round_dataframe[value].iloc[0], - inline=False - ), style={"width": "65%", 
'margin-left': '20px'}) - elif (type(self.round_dataframe[value].iloc[0]) == str) | \ - ((type(self.round_dataframe[value].iloc[0]) == np.int64) & - (len(self.round_dataframe[value].unique()) <= 20)): - new_element = html.Div(dcc.Dropdown( - id={ - 'type': 'dynamic-str', - 'index': id['index'] - }, - options=[{'label': i, 'value': i} for - i in np.sort(self.round_dataframe[value].unique())], - multi=True, - ), style={"width": "65%", 'margin-left': '20px'}) - elif ((type(self.round_dataframe[value].iloc[0]) is pd.Timestamp) | - (type(self.round_dataframe[value].iloc[0]) is datetime.datetime)): - new_element = html.Div( - dcc.DatePickerRange( - id={ - 'type': 'dynamic-date', - 'index': id['index'] - }, - min_date_allowed=self.round_dataframe[value].min(), - max_date_allowed=self.round_dataframe[value].max(), - start_date=self.round_dataframe[value].min(), - end_date=self.round_dataframe[value].max() - ), style={'width': '65%', 'margin-left': '20px'}), - else: - lower_value = 0 - upper_value = 0 - new_element = html.Div([ - dcc.Input( - id={ - 'type': 'lower', - 'index': id['index'] - }, - value=lower_value, - type="number", - style={'width': '60px'}), - ' <= {} in [{}, {}]<= '.format( - value, - self.round_dataframe[value].min(), - self.round_dataframe[value].max()), - dcc.Input( - id={ - 'type': 'upper', - 'index': id['index'] - }, - value=upper_value, - type="number", - style={'width': '60px'} - ) - ], style={'margin-left': '20px'}) + new_element = create_filter_modalities_selection(value, id, self.round_dataframe) else: new_element = html.Div() return new_element diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index 000e3715..79d8103c 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -2,6 +2,8 @@ import numpy as np import pandas as pd +import datetime +from dash import dcc, html import dash_bootstrap_components as dbc @@ -446,4 +448,182 @@ def create_id_card_layout(selected_data: pd.DataFrame, additional_features_dict: ]) ) - return children \ No newline at end of file + return children + + +def get_feature_filter_options(dataframe: pd.DataFrame, features_dict: dict, special_cols: list) -> list: + """Get the columns names options for the filter. + + Parameters + ---------- + dataframe : pd.DataFrame + Dataframe + features_dict : dict + Dictionary mapping technical feature names to domain names + special_cols : list + Sepcial columns about the index, the prediction... + + Returns + ------- + list + Options for the filter + """ + # We use domain name for feature name + dict_name = [features_dict[i] + for i in dataframe.drop(special_cols, axis=1).columns] + dict_id = [i for i in dataframe.drop(special_cols, axis=1).columns] + # Create dataframe to sort it by feature_name + df_feature_name = pd.DataFrame({'feature_name': dict_name, + 'feature_id': dict_id}) + df_feature_name = df_feature_name.sort_values( + by='feature_name').reset_index(drop=True) + # Options are sorted by feature_name + options = [{"label": i, "value": i} for i in special_cols] + \ + [{"label": df_feature_name.loc[i, 'feature_name'], + "value": df_feature_name.loc[i, 'feature_id']} + for i in range(len(df_feature_name))] + + return options + + +def create_dropdown_feature_filter(n_clicks_add: Optional[int], options: list) -> html.Div: + """Create a new dropdown for the filter feature selection. 
+ + Parameters + ---------- + n_clicks_add : Optional[int] + Number of clicks on the add filter button + options : list + Options for the selection + + Returns + ------- + html.Div + Div containing the dropdown + """ + # ID index definition + if n_clicks_add is None: + index_id = 0 + else: + index_id = n_clicks_add + # Appending a dropdown block to 'dropdowns_container'children + subset_filter = html.Div( + id={'type': 'bloc_div', + 'index': index_id}, + children=[ + html.Div([ + html.Br(), + # div which will contains label + html.Div( + id={'type': 'dynamic-output-label', + 'index': index_id}, + ) + ]), + html.Div([ + # div with dopdown button to select feature to filter + html.Div(dcc.Dropdown( + id={'type': 'var_dropdown', + 'index': index_id}, + options=options, + placeholder="Variable" + ), style={"width": "30%"}), + # div which will contains modalities + html.Div( + id={'type': 'dynamic-output', + 'index': index_id}, + style={"width": "50%"} + ), + # Button to delete bloc + dbc.Button( + id={'type': 'del_dropdown_button', + 'index': index_id}, + children='X', + color='warning', + size='sm' + ) + ], style={'display': 'flex'}) + ] + ) + + return subset_filter + + +def create_filter_modalities_selection(value: str, id: dict, round_dataframe: pd.DataFrame) -> html.Div: + """Create the modalities filter according to the feature type. + + Parameters + ---------- + value : str + feature name + id : dict + id of the filter + round_dataframe : pd.DataFrame + Dataframe + + Returns + ------- + html.Div + Div containing the modalities selection options + """ + if type(round_dataframe[value].iloc[0]) == bool: + new_element = html.Div(dcc.RadioItems( + [{'label': val, 'value': val} for + val in round_dataframe[value].unique()], + id={'type': 'dynamic-bool', + 'index': id['index']}, + value=round_dataframe[value].iloc[0], + inline=False + ), style={"width": "65%", 'margin-left': '20px'}) + elif (type(round_dataframe[value].iloc[0]) == str) | \ + ((type(round_dataframe[value].iloc[0]) == np.int64) & + (len(round_dataframe[value].unique()) <= 20)): + new_element = html.Div(dcc.Dropdown( + id={ + 'type': 'dynamic-str', + 'index': id['index'] + }, + options=[{'label': i, 'value': i} for + i in np.sort(round_dataframe[value].unique())], + multi=True, + ), style={"width": "65%", 'margin-left': '20px'}) + elif ((type(round_dataframe[value].iloc[0]) is pd.Timestamp) | + (type(round_dataframe[value].iloc[0]) is datetime.datetime)): + new_element = html.Div( + dcc.DatePickerRange( + id={ + 'type': 'dynamic-date', + 'index': id['index'] + }, + min_date_allowed=round_dataframe[value].min(), + max_date_allowed=round_dataframe[value].max(), + start_date=round_dataframe[value].min(), + end_date=round_dataframe[value].max() + ), style={'width': '65%', 'margin-left': '20px'}), + else: + lower_value = 0 + upper_value = 0 + new_element = html.Div([ + dcc.Input( + id={ + 'type': 'lower', + 'index': id['index'] + }, + value=lower_value, + type="number", + style={'width': '60px'}), + ' <= {} in [{}, {}]<= '.format( + value, + round_dataframe[value].min(), + round_dataframe[value].max()), + dcc.Input( + id={ + 'type': 'upper', + 'index': id['index'] + }, + value=upper_value, + type="number", + style={'width': '60px'} + ) + ], style={'margin-left': '20px'}) + + return new_element diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index f3123e03..e8711428 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ 
b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd from sklearn.tree import DecisionTreeClassifier +from dash import dcc from shapash import SmartExplainer from shapash.webapp.smart_app import SmartApp from shapash.webapp.utils.callbacks import ( @@ -20,6 +21,8 @@ get_id_card_contrib, create_id_card_data, create_id_card_layout, + get_feature_filter_options, + create_filter_modalities_selection ) @@ -471,4 +474,17 @@ def test_create_id_card_layout(self): index = ['_index_', '_predict_', '_target_', 'column3', 'column1', '_column2'] ) children = create_id_card_layout(selected_data, self.xpl.additional_features_dict) - assert len(children)==6 \ No newline at end of file + assert len(children)==6 + + def test_get_feature_filter_options(self): + features_dict = copy.deepcopy(self.xpl.features_dict) + features_dict.update(self.xpl.additional_features_dict) + options = get_feature_filter_options(self.smart_app.dataframe, features_dict, self.special_cols) + assert [option["label"] for option in options]==['_index_', '_predict_', '_target_', 'Useless col', '_Additional col', 'column1'] + + def test_create_filter_modalities_selection(self): + new_element = create_filter_modalities_selection("column3", {'type': 'var_dropdown', 'index': 1}, self.smart_app.round_dataframe) + assert type(new_element.children[0])==dcc.Input + + new_element = create_filter_modalities_selection("_column2", {'type': 'var_dropdown', 'index': 1}, self.smart_app.round_dataframe) + assert type(new_element.children)==dcc.Dropdown \ No newline at end of file From aa15940c43caa9d9d9024238d433d97617a0543d Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Tue, 4 Jul 2023 11:17:26 +0200 Subject: [PATCH 11/12] Better coverage and delete non useful imports. 
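
Mostly test-side changes: the fixture now feeds boolean and datetime columns through the additional data so that the id card, filter and datatable helpers are exercised on every dtype, and the now unused numpy/datetime imports are dropped from smart_app.py. For reference, the enriched fixture boils down to the sketch below (it mirrors the test setup; the compile keyword names are taken from the existing fixture, not new API):

    import pandas as pd
    from sklearn.tree import DecisionTreeClassifier
    from shapash import SmartExplainer

    df = pd.DataFrame({
        'column1': [1, 2, 3, 4, 5],
        'column2': ['a', 'b', 'c', 'd', 'e'],
        'column3': [1.1, 3.3, 2.2, 4.4, 5.5],
        'column4': [True, False, True, False, False],
        'column5': pd.date_range('2023-01-01', periods=5),
    })
    X = df[['column1', 'column3']]
    y = pd.DataFrame({'pred': [0, 0, 0, 1, 1]})
    model = DecisionTreeClassifier().fit(X, y)

    xpl = SmartExplainer(model=model, features_dict={'column3': 'Useless col'})
    xpl.compile(
        x=X,
        y_target=y,
        additional_data=df[['column2', 'column4', 'column5']],  # bool and datetime now covered
        additional_features_dict={'column2': 'Additional col'},
    )

Inside the webapp the additional columns are prefixed with an underscore, which is why the expected datatable in the tests contains '_column4' and '_column5'.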
--- shapash/webapp/smart_app.py | 2 - .../unit_tests/webapp/utils/test_callbacks.py | 62 +++++++++++++++---- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index dcca5861..5bddd587 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -16,8 +16,6 @@ import pandas as pd import plotly.graph_objs as go import random -import numpy as np -import datetime import re from math import log10 from shapash.webapp.utils.utils import check_row, get_index_type, round_to_k diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index e8711428..ad3801b2 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs): y_target = pd.DataFrame(data=np.array([0, 0, 0, 1, 1]), columns=['pred']) model = DecisionTreeClassifier().fit(dataframe_x, y_target) features_dict = {'column3': 'Useless col'} - additional_data = df[['column2']].copy() + additional_data = df[['column2','column4','column5']].copy() additional_features_dict = {'column2': 'Additional col'} self.xpl = SmartExplainer(model=model, features_dict=features_dict) self.xpl.compile( @@ -85,6 +85,8 @@ def test_default_init_data(self): 'column1': [1, 2, 3, 4, 5], 'column3': [1.1, 3.3, 2.2, 4.4, 5.5], '_column2': ['a', 'b', 'c', 'd', 'e'], + '_column4': [True, False, True, False, False], + '_column5': pd.date_range('2023-01-01', periods=5), }, ) self.smart_app.init_data() @@ -417,10 +419,19 @@ def test_get_id_card_features(self): selected_row = get_id_card_features(data, 3, self.special_cols, features_dict) expected_result = pd.DataFrame( { - 'feature_value': [3, 1, 1, 4, 4.4, 'd'], - 'feature_name': ['_index_', '_predict_', '_target_', 'column1', 'Useless col', '_Additional col'], + 'feature_value': [3, 1, 1, 4, 4.4, 'd', False, pd.Timestamp('2023-01-04')], + 'feature_name': [ + '_index_', + '_predict_', + '_target_', + 'column1', + 'Useless col', + '_Additional col', + '_column4', + '_column5', + ], }, - index = ['_index_', '_predict_', '_target_', 'column1', 'column3', '_column2'] + index = ['_index_', '_predict_', '_target_', 'column1', 'column3', '_column2', '_column4', '_column5'] ) pd.testing.assert_frame_equal(selected_row, expected_result) @@ -433,10 +444,19 @@ def test_get_id_card_contrib(self): def test_create_id_card_data(self): selected_row = pd.DataFrame( { - 'feature_value': [3, 1, 1, 4, 4.4, 'd'], - 'feature_name': ['_index_', '_predict_', '_target_', 'column1', 'Useless col', '_Additional col'], + 'feature_value': [3, 1, 1, 4, 4.4, 'd', False, pd.Timestamp('2023-01-04')], + 'feature_name': [ + '_index_', + '_predict_', + '_target_', + 'column1', + 'Useless col', + '_Additional col', + '_column4', + '_column5', + ], }, - index = ['_index_', '_predict_', '_target_', 'column1', 'column3', '_column2'] + index = ['_index_', '_predict_', '_target_', 'column1', 'column3', '_column2', '_column4', '_column5'] ) selected_contrib = pd.DataFrame( @@ -456,11 +476,20 @@ def test_create_id_card_data(self): ) expected_result = pd.DataFrame( { - 'feature_value': [3, 1, 1, 4.4, 4, 'd'], - 'feature_name': ['_index_', '_predict_', '_target_', 'Useless col', 'column1', '_Additional col'], - 'feature_contrib': [np.nan, np.nan, np.nan, 0.0, -0.6, np.nan] + 'feature_value': [3, 1, 1, 4.4, 4, 'd', False, pd.Timestamp('2023-01-04')], + 'feature_name': [ + '_index_', + '_predict_', + '_target_', + 'Useless col', 
+ 'column1', + '_Additional col', + '_column4', + '_column5', + ], + 'feature_contrib': [np.nan, np.nan, np.nan, 0.0, -0.6, np.nan, np.nan, np.nan] }, - index = ['_index_', '_predict_', '_target_', 'column3', 'column1', '_column2'] + index = ['_index_', '_predict_', '_target_', 'column3', 'column1', '_column2', '_column4', '_column5'] ) pd.testing.assert_frame_equal(selected_data, expected_result) @@ -480,7 +509,16 @@ def test_get_feature_filter_options(self): features_dict = copy.deepcopy(self.xpl.features_dict) features_dict.update(self.xpl.additional_features_dict) options = get_feature_filter_options(self.smart_app.dataframe, features_dict, self.special_cols) - assert [option["label"] for option in options]==['_index_', '_predict_', '_target_', 'Useless col', '_Additional col', 'column1'] + assert [option["label"] for option in options]==[ + '_index_', + '_predict_', + '_target_', + 'Useless col', + '_Additional col', + '_column4', + '_column5', + 'column1', + ] def test_create_filter_modalities_selection(self): new_element = create_filter_modalities_selection("column3", {'type': 'var_dropdown', 'index': 1}, self.smart_app.round_dataframe) From 38d95400ffbd78687313d0d8ac9a665d34e81868 Mon Sep 17 00:00:00 2001 From: MaximeLecardonnel6x7 Date: Wed, 6 Sep 2023 10:38:26 +0200 Subject: [PATCH 12/12] Correction refacto (trop d'arg, isnan, var max). --- shapash/webapp/smart_app.py | 43 ++-- shapash/webapp/utils/callbacks.py | 201 ++++++++++++------ .../unit_tests/webapp/utils/test_callbacks.py | 143 +++---------- 3 files changed, 194 insertions(+), 193 deletions(-) diff --git a/shapash/webapp/smart_app.py b/shapash/webapp/smart_app.py index 5bddd587..c8ec3f23 100644 --- a/shapash/webapp/smart_app.py +++ b/shapash/webapp/smart_app.py @@ -24,7 +24,10 @@ from shapash.webapp.utils.explanations import Explanations from shapash.webapp.utils.callbacks import ( select_data_from_prediction_picking, - select_data_from_filters, + select_data_from_str_filters, + select_data_from_bool_filters, + select_data_from_numeric_filters, + select_data_from_date_filters, get_feature_from_clicked_data, get_feature_from_features_groups, get_group_name, @@ -1390,8 +1393,6 @@ def callback_generator(self): State({'type': 'lower', 'index': ALL}, 'value'), State({'type': 'lower', 'index': ALL}, 'id'), State({'type': 'upper', 'index': ALL}, 'value'), - State({'type': 'upper', 'index': ALL}, 'id'), - State('dropdowns_container', 'children') ] ) def update_datatable(selected_data, @@ -1412,9 +1413,8 @@ def update_datatable(selected_data, id_date, val_lower_modality, id_lower_modality, - val_upper_modality, - id_upper_modality, - children): + val_upper_modality + ): """ This function is used to update the datatable according to sorting, filtering and settings modifications. 
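
Follow-up fixes on the refactor: select_data_from_filters took too many arguments once wired to the pattern-matching callbacks, so it is split into four dtype-specific helpers that update_datatable now chains on a copy of the rounded dataframe; the NaN self-comparison on feature_contrib is replaced by np.isnan; and the local variable max is renamed max_value so the builtin is no longer shadowed. A small self-contained sketch of the new chaining (values copied from the unit tests; the bool, numeric and date filters are deliberately left empty here):

    import pandas as pd
    from shapash.webapp.utils.callbacks import (
        select_data_from_str_filters,
        select_data_from_bool_filters,
        select_data_from_numeric_filters,
        select_data_from_date_filters,
    )

    round_dataframe = pd.DataFrame({
        'column1': [1, 2, 3, 4, 5],
        'column2': ['a', 'b', 'c', 'd', 'e'],
    })
    df = round_dataframe.copy()

    # ids coming from the pattern-matching components of the filter blocks
    id_feature = [{'type': 'var_dropdown', 'index': 1}]
    feature_id = [i['index'] for i in id_feature]

    # Keep rows where column2 is 'a' or 'c'; the three other filters are no-ops
    df = select_data_from_str_filters(
        df, feature_id,
        [{'type': 'dynamic-str', 'index': 1}], ['column2'], [['a', 'c']],
    )
    df = select_data_from_bool_filters(df, feature_id, [], ['column2'], [])
    df = select_data_from_numeric_filters(df, feature_id, [], ['column2'], [], [])
    df = select_data_from_date_filters(df, feature_id, [], ['column2'], [], [])
    # df now holds rows 0 and 2 only

Splitting by dtype also lets each helper return early when its id list is empty, which keeps update_datatable readable.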
@@ -1437,8 +1437,6 @@ def update_datatable(selected_data, val_lower_modality: lower values of numeric filter id_lower_modality: id of lower modalities of numeric filter val_upper_modality: upper values of numeric filter - id_upper_modality: id of upper values of numeric filter - children: children of dropdown container ------------------------------------------------------------------ return data: available dataset @@ -1475,18 +1473,35 @@ def update_datatable(selected_data, (selected_data is not None and len(selected_data) == 1 and len(selected_data['points'])!=0 and selected_data['points'][0]['curveNumber'] > 0) )): - df = select_data_from_filters( - self.round_dataframe, - id_feature, + df = self.round_dataframe.copy() + feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] + df = select_data_from_str_filters( + df, + feature_id, id_str_modality, - id_bool_modality, - id_lower_modality, - id_date, val_feature, val_str_modality, + ) + df = select_data_from_bool_filters( + df, + feature_id, + id_bool_modality, + val_feature, val_bool_modality, + ) + df = select_data_from_numeric_filters( + df, + feature_id, + id_lower_modality, + val_feature, val_lower_modality, val_upper_modality, + ) + df = select_data_from_date_filters( + df, + feature_id, + id_date, + val_feature, start_date, end_date, ) diff --git a/shapash/webapp/utils/callbacks.py b/shapash/webapp/utils/callbacks.py index 79d8103c..3275e29c 100644 --- a/shapash/webapp/utils/callbacks.py +++ b/shapash/webapp/utils/callbacks.py @@ -30,47 +30,152 @@ def select_data_from_prediction_picking(round_dataframe: pd.DataFrame, selected_ return df -def select_data_from_filters( - round_dataframe: pd.DataFrame, - id_feature: list, +def select_data_from_str_filters( + df: pd.DataFrame, + feature_id: list, id_str_modality: list, - id_bool_modality: list, - id_lower_modality: list, - id_date: list, val_feature: list, val_str_modality: list, - val_bool_modality: list, - val_lower_modality: list, - val_upper_modality: list, - start_date: list, - end_date: list, ) -> pd.DataFrame: """Create a subset dataframe from filters. Parameters ---------- - round_dataframe : pd.DataFrame + df : pd.DataFrame Data to sample - id_feature : list + feature_id : list features ids id_str_modality : list string features ids - id_bool_modality : list - boolean features ids - id_lower_modality : list - numeric features ids - id_date : list - date features ids val_feature : list features names val_str_modality : list string modalities selected + + Returns + ------- + pd.DataFrame + Subset dataframe + """ + # get list of ID + str_id = [id_str_modality[i]['index'] for i in range(len(id_str_modality))] + # If there is some filters + if len(str_id) > 0: + for i in range(len(feature_id)): + if feature_id[i] in str_id: + position = np.where(np.array(str_id) == feature_id[i])[0][0] + if ((position is not None) & (val_str_modality[position] is not None)): + df = df[df[val_feature[i]].isin(val_str_modality[position])] + + return df + + +def select_data_from_bool_filters( + df: pd.DataFrame, + feature_id: list, + id_bool_modality: list, + val_feature: list, + val_bool_modality: list, +) -> pd.DataFrame: + """Create a subset dataframe from filters. 
+ + Parameters + ---------- + df : pd.DataFrame + Data to sample + feature_id : list + features ids + id_bool_modality : list + boolean features ids + val_feature : list + features names val_bool_modality : list boolean modalities selected + + Returns + ------- + pd.DataFrame + Subset dataframe + """ + # get list of ID + bool_id = [id_bool_modality[i]['index'] for i in range(len(id_bool_modality))] + # If there is some filters + if len(bool_id) > 0: + for i in range(len(feature_id)): + if feature_id[i] in bool_id: + position = np.where(np.array(bool_id) == feature_id[i])[0][0] + if ((position is not None) & (val_bool_modality[position] is not None)): + df = df[df[val_feature[i]] == val_bool_modality[position]] + + return df + + +def select_data_from_numeric_filters( + df: pd.DataFrame, + feature_id: list, + id_lower_modality: list, + val_feature: list, + val_lower_modality: list, + val_upper_modality: list, +) -> pd.DataFrame: + """Create a subset dataframe from filters. + + Parameters + ---------- + df : pd.DataFrame + Data to sample + feature_id : list + features ids + id_lower_modality : list + numeric features ids + val_feature : list + features names val_lower_modality : list lower values of numeric filter val_upper_modality : list upper values of numeric filter + + Returns + ------- + pd.DataFrame + Subset dataframe + """ + # get list of ID + lower_id = [id_lower_modality[i]['index'] for i in range(len(id_lower_modality))] + # If there is some filters + if len(lower_id) > 0: + for i in range(len(feature_id)): + if feature_id[i] in lower_id: + position = np.where(np.array(lower_id) == feature_id[i])[0][0] + if((position is not None) & (val_lower_modality[position] is not None) & + (val_upper_modality[position] is not None)): + if (val_lower_modality[position] < val_upper_modality[position]): + df = df[(df[val_feature[i]] >= val_lower_modality[position]) & + (df[val_feature[i]] <= val_upper_modality[position])] + + return df + + +def select_data_from_date_filters( + df: pd.DataFrame, + feature_id: list, + id_date: list, + val_feature: list, + start_date: list, + end_date: list, +) -> pd.DataFrame: + """Create a subset dataframe from filters. 
+ + Parameters + ---------- + round_dataframe : pd.DataFrame + Data to sample + id_feature : list + features ids + id_date : list + date features ids + val_feature : list + features names start_date : list start dates selected end_date : list @@ -82,41 +187,17 @@ def select_data_from_filters( Subset dataframe """ # get list of ID - feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] - str_id = [id_str_modality[i]['index'] for i in range(len(id_str_modality))] - bool_id = [id_bool_modality[i]['index'] for i in range(len(id_bool_modality))] - lower_id = [id_lower_modality[i]['index'] for i in range(len(id_lower_modality))] date_id = [id_date[i]['index'] for i in range(len(id_date))] - df = round_dataframe # If there is some filters - if len(feature_id) > 0: + if len(date_id) > 0: for i in range(len(feature_id)): - # String filter - if feature_id[i] in str_id: - position = np.where(np.array(str_id) == feature_id[i])[0][0] - if ((position is not None) & (val_str_modality[position] is not None)): - df = df[df[val_feature[i]].isin(val_str_modality[position])] - # Boolean filter - elif feature_id[i] in bool_id: - position = np.where(np.array(bool_id) == feature_id[i])[0][0] - if ((position is not None) & (val_bool_modality[position] is not None)): - df = df[df[val_feature[i]] == val_bool_modality[position]] - # Date filter - elif feature_id[i] in date_id: + if feature_id[i] in date_id: position = np.where(np.array(date_id) == feature_id[i])[0][0] if((position is not None) & (start_date[position] < end_date[position])): df = df[((df[val_feature[i]] >= start_date[position]) & (df[val_feature[i]] <= end_date[position]))] - # Numeric filter - elif feature_id[i] in lower_id: - position = np.where(np.array(lower_id) == feature_id[i])[0][0] - if((position is not None) & (val_lower_modality[position] is not None) & - (val_upper_modality[position] is not None)): - if (val_lower_modality[position] < val_upper_modality[position]): - df = df[(df[val_feature[i]] >= val_lower_modality[position]) & - (df[val_feature[i]] <= val_upper_modality[position])] - + return df @@ -283,24 +364,24 @@ def update_features_to_display(features: int, nb_columns: int, value: int) -> Tu Tuple[int, int, dict] Number of columns to plot, Number max of columns to plot, Marks in the slider """ - max = min(features, nb_columns) - if max % 5 == 0: - nb_marks = min(int(max // 5), 10) - elif max % 4 == 0: - nb_marks = min(int(max // 4), 10) - elif max % 3 == 0: - nb_marks = min(int(max // 3), 10) - elif max % 7 == 0: - nb_marks = min(int(max // 7), 10) + max_value = min(features, nb_columns) + if max_value % 5 == 0: + nb_marks = min(int(max_value // 5), 10) + elif max_value % 4 == 0: + nb_marks = min(int(max_value // 4), 10) + elif max_value % 3 == 0: + nb_marks = min(int(max_value // 3), 10) + elif max_value % 7 == 0: + nb_marks = min(int(max_value // 7), 10) else: nb_marks = 2 - marks = {f'{round(max * feat / nb_marks)}': f'{round(max * feat / nb_marks)}' + marks = {f'{round(max_value * feat / nb_marks)}': f'{round(max_value * feat / nb_marks)}' for feat in range(1, nb_marks + 1)} marks['1'] = '1' - if max < value: - value = max + if max_value < value: + value = max_value - return value, max, marks + return value, max_value, marks def get_id_card_features(data: list, selected: int, special_cols: list, features_dict: dict) -> pd.DataFrame: @@ -444,7 +525,7 @@ def create_id_card_layout(selected_data: pd.DataFrame, additional_features_dict: ), width=2, className="id_card_solid", - ) if 
row["feature_contrib"]==row["feature_contrib"] else None, + ) if not np.isnan(row["feature_contrib"]) else None, ]) ) diff --git a/tests/unit_tests/webapp/utils/test_callbacks.py b/tests/unit_tests/webapp/utils/test_callbacks.py index ad3801b2..ec336843 100644 --- a/tests/unit_tests/webapp/utils/test_callbacks.py +++ b/tests/unit_tests/webapp/utils/test_callbacks.py @@ -8,7 +8,10 @@ from shapash.webapp.smart_app import SmartApp from shapash.webapp.utils.callbacks import ( select_data_from_prediction_picking, - select_data_from_filters, + select_data_from_str_filters, + select_data_from_bool_filters, + select_data_from_numeric_filters, + select_data_from_date_filters, get_feature_from_clicked_data, get_feature_from_features_groups, get_group_name, @@ -111,20 +114,13 @@ def test_select_data_from_prediction_picking(self): result = select_data_from_prediction_picking(self.df, selected_data) pd.testing.assert_frame_equal(expected_result, result) - def test_select_data_from_filters_string(self): + def test_select_data_from_str_filters(self): round_dataframe = self.df id_feature = [{'type': 'var_dropdown', 'index': 1}] + feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] id_str_modality = [{'type': 'dynamic-str', 'index': 1}] - id_bool_modality = [] - id_lower_modality = [] - id_date = [] val_feature = ['column2'] val_str_modality = [['a', 'c']] - val_bool_modality = [] - val_lower_modality = [] - val_upper_modality = [] - start_date = [] - end_date = [] expected_result = pd.DataFrame( { @@ -136,37 +132,22 @@ def test_select_data_from_filters_string(self): }, index=[0, 2] ) - result = select_data_from_filters( + result = select_data_from_str_filters( round_dataframe, - id_feature, + feature_id, id_str_modality, - id_bool_modality, - id_lower_modality, - id_date, val_feature, val_str_modality, - val_bool_modality, - val_lower_modality, - val_upper_modality, - start_date, - end_date, ) pd.testing.assert_frame_equal(expected_result, result) - def test_select_data_from_filters_bool(self): + def test_select_data_from_bool_filters(self): round_dataframe = self.df id_feature = [{'type': 'var_dropdown', 'index': 2}] - id_str_modality = [] + feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] id_bool_modality = [{'type': 'dynamic-bool', 'index': 2}] - id_lower_modality = [] - id_date = [] val_feature = ['column4'] - val_str_modality = [] val_bool_modality = [True] - val_lower_modality = [] - val_upper_modality = [] - start_date = [] - end_date = [] expected_result = pd.DataFrame( { @@ -178,35 +159,21 @@ def test_select_data_from_filters_bool(self): }, index=[0, 2] ) - result = select_data_from_filters( + result = select_data_from_bool_filters( round_dataframe, - id_feature, - id_str_modality, - id_bool_modality, - id_lower_modality, - id_date, + feature_id, + id_bool_modality, val_feature, - val_str_modality, val_bool_modality, - val_lower_modality, - val_upper_modality, - start_date, - end_date, ) pd.testing.assert_frame_equal(expected_result, result) - def test_select_data_from_filters_date(self): + def test_select_data_from_date_filters(self): round_dataframe = self.df id_feature = [{'type': 'var_dropdown', 'index': 1}] - id_str_modality = [] - id_bool_modality = [] - id_lower_modality = [] + feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] id_date = [{'type': 'dynamic-date', 'index': 1}] val_feature = ['column5'] - val_str_modality = [] - val_bool_modality = [] - val_lower_modality = [] - val_upper_modality = [] start_date = 
[pd.Timestamp('2023-01-01')] end_date = [pd.Timestamp('2023-01-03')] @@ -220,79 +187,24 @@ def test_select_data_from_filters_date(self): }, index=[0, 1, 2] ) - result = select_data_from_filters( + result = select_data_from_date_filters( round_dataframe, - id_feature, - id_str_modality, - id_bool_modality, - id_lower_modality, + feature_id, id_date, - val_feature, - val_str_modality, - val_bool_modality, - val_lower_modality, - val_upper_modality, + val_feature, start_date, end_date, ) pd.testing.assert_frame_equal(expected_result, result) - def test_select_data_from_filters_numeric(self): + def test_select_data_from_numeric_filters(self): round_dataframe = self.df id_feature = [{'type': 'var_dropdown', 'index': 1}, {'type': 'var_dropdown', 'index': 2}] - id_str_modality = [] - id_bool_modality = [] + feature_id = [id_feature[i]['index'] for i in range(len(id_feature))] id_lower_modality = [{'type': 'lower', 'index': 1}, {'type': 'lower', 'index': 2}] - id_date = [] val_feature = ['column1', 'column3'] - val_str_modality = [] - val_bool_modality = [] val_lower_modality = [0, 0] val_upper_modality = [3, 3] - start_date = [] - end_date = [] - - expected_result = pd.DataFrame( - { - 'column1': [1, 3], - 'column2': ['a', 'c'], - 'column3': [1.1, 2.2], - 'column4': [True, True], - 'column5': [pd.Timestamp('2023-01-01'), pd.Timestamp('2023-01-03')], - }, - index=[0, 2] - ) - result = select_data_from_filters( - round_dataframe, - id_feature, - id_str_modality, - id_bool_modality, - id_lower_modality, - id_date, - val_feature, - val_str_modality, - val_bool_modality, - val_lower_modality, - val_upper_modality, - start_date, - end_date, - ) - pd.testing.assert_frame_equal(expected_result, result) - - def test_select_data_from_filters_multi_types(self): - round_dataframe = self.df - id_feature = [{'type': 'var_dropdown', 'index': 1}, {'type': 'var_dropdown', 'index': 2}] - id_str_modality = [{'type': 'dynamic-str', 'index': 2}] - id_bool_modality = [] - id_lower_modality = [{'type': 'lower', 'index': 1}] - id_date = [] - val_feature = ['column1', 'column2'] - val_str_modality = [['a', 'c', 'd', 'e']] - val_bool_modality = [] - val_lower_modality = [0] - val_upper_modality = [3] - start_date = [] - end_date = [] expected_result = pd.DataFrame( { @@ -304,20 +216,13 @@ def test_select_data_from_filters_multi_types(self): }, index=[0, 2] ) - result = select_data_from_filters( + result = select_data_from_numeric_filters( round_dataframe, - id_feature, - id_str_modality, - id_bool_modality, - id_lower_modality, - id_date, - val_feature, - val_str_modality, - val_bool_modality, + feature_id, + id_lower_modality, + val_feature, val_lower_modality, val_upper_modality, - start_date, - end_date, ) pd.testing.assert_frame_equal(expected_result, result)
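
For reference, a minimal usage sketch of the four filter helpers split out of select_data_from_filters in shapash/webapp/utils/callbacks.py, chained in the same order update_datatable now applies them (string, boolean, numeric, date). The toy DataFrame and the Dash pattern-matching id dicts below are illustrative only, modeled on the fixtures in test_callbacks.py; they are not part of the patch itself.

import pandas as pd

from shapash.webapp.utils.callbacks import (
    select_data_from_str_filters,
    select_data_from_bool_filters,
    select_data_from_numeric_filters,
    select_data_from_date_filters,
)

# Illustrative data, modeled on the test fixtures (assumption, not from the patch).
df = pd.DataFrame(
    {
        "column1": [1, 2, 3],
        "column2": ["a", "b", "c"],
        "column4": [True, False, True],
    }
)

# One var_dropdown entry per active filter block; feature_id is the list of raw
# indices, built the same way update_datatable now builds it.
id_feature = [{"type": "var_dropdown", "index": 1}, {"type": "var_dropdown", "index": 2}]
feature_id = [id_feature[i]["index"] for i in range(len(id_feature))]
val_feature = ["column2", "column4"]

# String filter on column2 and boolean filter on column4; the numeric and date
# helpers return the frame unchanged when their id lists are empty.
subset = select_data_from_str_filters(
    df.copy(),
    feature_id,
    id_str_modality=[{"type": "dynamic-str", "index": 1}],
    val_feature=val_feature,
    val_str_modality=[["a", "c"]],
)
subset = select_data_from_bool_filters(
    subset,
    feature_id,
    id_bool_modality=[{"type": "dynamic-bool", "index": 2}],
    val_feature=val_feature,
    val_bool_modality=[True],
)
subset = select_data_from_numeric_filters(
    subset,
    feature_id,
    id_lower_modality=[],
    val_feature=val_feature,
    val_lower_modality=[],
    val_upper_modality=[],
)
subset = select_data_from_date_filters(
    subset,
    feature_id,
    id_date=[],
    val_feature=val_feature,
    start_date=[],
    end_date=[],
)
print(subset)  # keeps rows 0 and 2: column2 in {"a", "c"} and column4 == True

Because each helper only consumes the id list for its own filter type, the chained calls are order-independent and each one is a no-op when no filter of that type is active, which is what lets update_datatable apply all four unconditionally.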