Skip to content

Commit

Permalink
Merge pull request #579 from guillaume-vignal/feature/new_feature_importance
Browse files Browse the repository at this point in the history

New Feature: Subpopulation-based Feature Importance Plots
  • Loading branch information
guillaume-vignal authored Sep 17, 2024
2 parents 41ec571 + c6fe611 commit d03dca1
Show file tree
Hide file tree
Showing 16 changed files with 1,298 additions and 418 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ repos:
exclude: ^(docs/|gdocs/)
- id: check-added-large-files
args: ['--maxkb=500']
- id: no-commit-to-branch
args: ['--branch', 'master', '--branch', 'develop']
- id: no-commit-to-branch
args: ['--branch', 'master', '--branch', 'develop']
- repo: https://github.com/psf/black
rev: 21.12b0
hooks:
Expand Down
4 changes: 2 additions & 2 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ build:

# Optionally set the version of Python and requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
install:
- requirements: docs/requirements.txt
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
version=version_d["__version__"],
python_requires=">3.8, <3.13",
url="https://github.com/MAIF/shapash",
author="Yann Golhen, Sebastien Bidault, Yann Lagre, Maxime Gendre",
author="Yann Golhen, Sebastien Bidault, Yann Lagre, Maxime Gendre, Thomas Bouché, Maxime Lecardonnel, Guillaume Vignal",
author_email="[email protected]",
description="Shapash is a Python library which aims to make machine learning interpretable and understandable by everyone.",
long_description=long_description,
Expand Down
8 changes: 6 additions & 2 deletions shapash/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ def get_local_contributions(
return local_contributions

def get_global_features_importance(
self, contributions: pd.DataFrame, explain_data: Optional[dict] = None, subset: Optional[List[int]] = None
self,
contributions: pd.DataFrame,
explain_data: Optional[dict] = None,
subset: Optional[List[int]] = None,
norm: int = 1,
) -> Union[pd.Series, List[pd.Series]]:
"""Get global contributions using the explainer data computed in the `run_explainer`
method.
Expand All @@ -132,7 +136,7 @@ def get_global_features_importance(
contributions = [c.loc[subset] for c in contributions]
else:
contributions = contributions.loc[subset]
return state.compute_features_import(contributions)
return state.compute_features_import(contributions, norm)

def format_and_aggregate_local_contributions(
self,
Expand Down
10 changes: 5 additions & 5 deletions shapash/explainer/multi_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,23 @@ def summarize(self, s_contribs, var_dicts, xs_sorted, masks, columns_dict, featu
arg_tup = list(zip(s_contribs, var_dicts, xs_sorted, masks))
return self.delegate("summarize", arg_tup, columns_dict, features_dict)

def compute_features_import(self, contributions):
def compute_features_import(self, contributions, norm=1):
"""
Compute a relative features importance, sum of absolute values
of the contributions for each
features importance computed in base 100
of the contributions for each
features importance computed in base 100
Parameters
----------
contributions : list
list of pandas.DataFrames containing contributions
list of pandas.DataFrames containing contributions
Returns
-------
list
list of features importance pandas.series
"""
return self.delegate("compute_features_import", contributions)
return self.delegate("compute_features_import", contributions, norm)

def compute_grouped_contributions(self, contributions, features_groups):
"""
Expand Down
27 changes: 21 additions & 6 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Smart explainer module
"""

import copy
import logging
import shutil
Expand Down Expand Up @@ -217,13 +218,12 @@ def __init__(
self.backend_kwargs = backend_kwargs
self.features_dict = dict() if features_dict is None else copy.deepcopy(features_dict)
self.label_dict = label_dict
self.plot = SmartPlotter(self)
self.title_story = title_story if title_story is not None else ""
self.palette_name = palette_name if palette_name else "default"
self.colors_dict = copy.deepcopy(select_palette(colors_loading(), self.palette_name))
if colors_dict is not None:
self.colors_dict.update(colors_dict)
self.plot.define_style_attributes(colors_dict=self.colors_dict)
self.plot = SmartPlotter(self, self.colors_dict)

self._case, self._classes = check_model(self.model)
self.postprocessing = postprocessing
Expand Down Expand Up @@ -359,7 +359,7 @@ def _compile_features_groups(self, features_groups):
Performs required computations for groups of features.
"""
if self.backend.support_groups is False:
raise AssertionError(f"Selected backend ({self.backend.name}) " f"does not support groups of features.")
raise AssertionError(f"Selected backend ({self.backend.name}) does not support groups of features.")
# Compute contributions for groups of features
self.contributions_groups = self.state.compute_grouped_contributions(self.contributions, features_groups)
self.features_imp_groups = None
Expand Down Expand Up @@ -931,7 +931,7 @@ def to_pandas(

return pd.concat([y_pred, summary], axis=1)

def compute_features_import(self, force=False):
def compute_features_import(self, force=False, local=False):
"""
Compute a relative features importance, sum of absolute values
of the contributions for each.
Expand All @@ -949,11 +949,26 @@ def compute_features_import(self, force=False):
index of the series = contributions.columns
"""
self.features_imp = self.backend.get_global_features_importance(
contributions=self.contributions, explain_data=self.explain_data, subset=None
contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=1
)

if local:
self.features_imp_local_lev1 = self.backend.get_global_features_importance(
contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=3
)
self.features_imp_local_lev2 = self.backend.get_global_features_importance(
contributions=self.contributions, explain_data=self.explain_data, subset=None, norm=7
)

if self.features_groups is not None and self.features_imp_groups is None:
self.features_imp_groups = self.state.compute_features_import(self.contributions_groups)
self.features_imp_groups = self.state.compute_features_import(self.contributions_groups, norm=1)
if local:
self.features_imp_groups_local_lev1 = self.state.compute_features_import(
self.contributions_groups, norm=3
)
self.features_imp_groups_local_lev2 = self.state.compute_features_import(
self.contributions_groups, norm=7
)

def compute_features_stability(self, selection):
"""
Expand Down
Loading

0 comments on commit d03dca1

Please sign in to comment.