Skip to content

Commit

Permalink
Merge pull request #574 from guillaume-vignal/feature/new_feature_imp…
Browse files Browse the repository at this point in the history
…ortance

Feature/new feature importance
  • Loading branch information
guillaume-vignal authored Sep 3, 2024
2 parents e659e72 + 18565d8 commit 7187c96
Show file tree
Hide file tree
Showing 6 changed files with 415 additions and 118 deletions.
38 changes: 31 additions & 7 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ def _create_jittered_points(
- A numpy array of jittered points.
"""
# Creating jittered points
jitter = np.random.normal(mean, std, len(percentages))
rng = np.random.default_rng(seed=79)
jitter = rng.normal(mean, std, len(percentages))
if np.isnan(percentages).any():
percentages.fill(1)

Expand Down Expand Up @@ -583,7 +584,8 @@ def prepare_hover_text(self, feature_values, pred, feature_name):
def _add_violin_trace(self, fig, name, x, y, side, line_color, hovertext, secondary_y=True):
"""Adds a Violin trace to the figure."""
# Violin plot has a problem if for one violin all the points have the same contribution value
y = y + np.random.normal(size=y.shape) * (max(y.max(), 0) - min(y.min(), 0)) / 10 ** 8
rng = np.random.default_rng(seed=79)
y = y + rng.normal(size=y.shape) * (max(y.max(), 0) - min(y.min(), 0)) / 10 ** 8
violin_trace = go.Violin(
name=name,
x=x,
Expand Down Expand Up @@ -1718,6 +1720,7 @@ def contribution_plot(
def features_importance(
self,
max_features=20,
page="top",
selection=None,
label=-1,
group_name=None,
Expand All @@ -1742,6 +1745,8 @@ def features_importance(
max_features: int (optional, default 20)
this argument limit the number of hbar in features importance plot
if max_features is 20, plot selects the 20 most important features
page: int, str (optional, default "top")
enable to select the features to plot between "top", "worse" or the number of the page
selection: list (optional, default None)
This argument allows to represent the importance calculated with a subset.
Subset features importance is compared to global in the plot
Expand Down Expand Up @@ -1777,6 +1782,24 @@ def features_importance(
--------
>>> xpl.plot.features_importance()
"""

def get_feature_importance_page(features_importance, page, max_features):
    """Select the slice of features to display for the requested page.

    Parameters
    ----------
    features_importance : pandas.Series or pandas.DataFrame
        Feature importances sorted in ascending order (most important last).
    page : int or str
        "top" (or 1) for the most important features, "worst" for the
        least important ones, or a 1-based page number counting from the
        most important features.
    max_features : int
        Maximum number of features shown per page.

    Returns
    -------
    Same type as ``features_importance``: the rows belonging to the page.

    Raises
    ------
    ValueError
        If ``page`` is neither "top", "worst" nor an integer.
    """
    if isinstance(page, int):
        nb_features = len(features_importance)
        # Ceiling division. The previous formula (nb_features // max_features + 1)
        # over-counted when nb_features was an exact multiple of max_features,
        # creating a spurious empty last page.
        nb_page_max = max(1, -(-nb_features // max_features))
        # Wrap out-of-range page numbers back into [1, nb_page_max].
        page = (page - 1) % nb_page_max + 1

    if (page == "top") or (page == 1):
        return features_importance.tail(max_features)
    elif page == "worst":
        return features_importance.head(max_features)
    elif isinstance(page, int):
        # Pages count backwards from the most important features (end of the
        # data), hence the negative slice bounds; pandas clips an
        # out-of-range negative start to the beginning of the data.
        start_index = (page - 1) * max_features
        end_index = start_index + max_features
        return features_importance.iloc[-end_index:-start_index]
    else:
        raise ValueError("Invalid value for page. It must be 'top', 'worst', or an integer.")

self.explainer.compute_features_import(force=force)
subtitle = None
title = "Features Importance"
Expand Down Expand Up @@ -1809,17 +1832,18 @@ def features_importance(
# classification
if self.explainer._case == "classification":
label_num, _, label_value = self.explainer.check_label_name(label)
global_feat_imp = features_importance[label_num].tail(max_features)
global_feat_imp = get_feature_importance_page(features_importance[label_num], page, max_features)
if selection is not None:
subset_feat_imp = self.explainer.backend.get_global_features_importance(
contributions=contributions[label_num], explain_data=self.explainer.explain_data, subset=selection
)
else:
subset_feat_imp = None
subtitle = f"Response: <b>{label_value}</b>"

# regression
elif self.explainer._case == "regression":
global_feat_imp = features_importance.tail(max_features)
global_feat_imp = get_feature_importance_page(features_importance, page, max_features)
if selection is not None:
subset_feat_imp = self.explainer.backend.get_global_features_importance(
contributions=contributions, explain_data=self.explainer.explain_data, subset=selection
Expand Down Expand Up @@ -3652,10 +3676,10 @@ def _prediction_classification_plot(
df_wrong_predict.target.values.flatten(),
)
]

rng = np.random.default_rng(seed=79)
fig.add_trace(
go.Scatter(
x=df_correct_predict["target"].values.flatten() + np.random.normal(0, 0.02, len(df_correct_predict)),
x=df_correct_predict["target"].values.flatten() + rng.normal(0, 0.02, len(df_correct_predict)),
y=df_correct_predict["proba_values"].values.flatten(),
mode="markers",
marker_color=self._style_dict["prediction_plot"][1],
Expand All @@ -3669,7 +3693,7 @@ def _prediction_classification_plot(

fig.add_trace(
go.Scatter(
x=df_wrong_predict["target"].values.flatten() + np.random.normal(0, 0.02, len(df_wrong_predict)),
x=df_wrong_predict["target"].values.flatten() + rng.normal(0, 0.02, len(df_wrong_predict)),
y=df_wrong_predict["proba_values"].values.flatten(),
mode="markers",
marker_color=self._style_dict["prediction_plot"][0],
Expand Down
7 changes: 4 additions & 3 deletions shapash/utils/explanation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def _get_radius(dataset, n_neighbors, sample_size=500, percentile=95):
# Select 500 points max to sample
size = min([dataset.shape[0], sample_size])
# Randomly sample points from dataset
sampled_instances = dataset[np.random.randint(0, dataset.shape[0], size), :]
rng = np.random.default_rng(seed=79)
sampled_instances = dataset[rng.integers(0, dataset.shape[0], size), :]
# Define normalization vector
mean_vector = np.array(dataset, dtype=np.float32).std(axis=0)
# Initialize the similarity matrix
Expand All @@ -109,9 +110,9 @@ def _get_radius(dataset, n_neighbors, sample_size=500, percentile=95):
similarity_distance[i, j] = dist
similarity_distance[j, i] = dist
# Select top n_neighbors
ordered_X = np.sort(similarity_distance)[:, 1 : n_neighbors + 1]
ordered_x = np.sort(similarity_distance)[:, 1 : n_neighbors + 1]
# Select the value of the distance that captures XX% of all distances (percentile)
return np.percentile(ordered_X.flatten(), percentile)
return np.percentile(ordered_x.flatten(), percentile)


def find_neighbors(selection, dataset, model, mode, n_neighbors=10):
Expand Down
Loading

0 comments on commit 7187c96

Please sign in to comment.