Skip to content

Commit

Permalink
Merge pull request #535 from guillaume-vignal/feature/compile_preds
Browse files Browse the repository at this point in the history
compute predictions and probabilities in compile
  • Loading branch information
guillaume-vignal authored Mar 28, 2024
2 parents 9be4d36 + 0dad3f7 commit 2be43bd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
33 changes: 27 additions & 6 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def __init__(
self.features_imp = None

def compile(
self, x, contributions=None, y_pred=None, y_target=None, additional_data=None, additional_features_dict=None
self,
x,
contributions=None,
y_pred=None,
proba_values=None,
y_target=None,
additional_data=None,
additional_features_dict=None,
):
"""
The compile method is the first step to understand model and
Expand All @@ -266,6 +273,11 @@ def compile(
This is an interesting parameter for more explicit outputs.
Shapash lets users define their own predict,
as they may wish to set their own threshold (classification)
proba_values : pandas.Series or pandas.DataFrame, optional (default: None)
Probability values (1 column only).
The index must be identical to the index of x_init.
This is an interesting parameter for more explicit outputs.
Shapash lets users define their own probability values
y_target : pandas.Series or pandas.DataFrame, optional (default: None)
Target values (1 column only).
The index must be identical to the index of x_init.
Expand All @@ -291,6 +303,13 @@ def compile(
x_init = inverse_transform(self.x_encoded, self.preprocessing)
self.x_init = handle_categorical_missing(x_init)
self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
if (self.y_pred is None) and (hasattr(self.model, "predict")):
self.predict()

self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
if (self._case == "classification") and (self.proba_values is None) and (hasattr(self.model, "predict_proba")):
self.predict_proba()

self.y_target = check_y(self.x_init, y_target, y_name="y_target")
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)

Expand Down Expand Up @@ -405,6 +424,7 @@ def define_style(self, palette_name=None, colors_dict=None):
def add(
self,
y_pred=None,
proba_values=None,
y_target=None,
label_dict=None,
features_dict=None,
Expand All @@ -423,6 +443,9 @@ def add(
y_pred : pandas.Series, optional (default: None)
Prediction values (1 column only).
The index must be identical to the index of x_init.
proba_values : pandas.Series, optional (default: None)
Probability values (1 column only).
The index must be identical to the index of x_init.
label_dict: dict, optional (default: None)
Dictionary mapping integer labels to domain names.
features_dict: dict, optional (default: None)
Expand All @@ -446,6 +469,8 @@ def add(
self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
if hasattr(self, "y_target"):
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
if proba_values is not None:
self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
if y_target is not None:
self.y_target = check_y(self.x_init, y_target, y_name="y_target")
if hasattr(self, "y_pred"):
Expand Down Expand Up @@ -895,7 +920,7 @@ def to_pandas(
)
# Matching with y_pred
if proba:
self.predict_proba() if proba else None
self.predict_proba()
proba_values = self.proba_values
else:
proba_values = None
Expand Down Expand Up @@ -1006,8 +1031,6 @@ def init_app(self, settings: dict = None):
Possible settings (dict keys) are 'rows', 'points', 'violin', 'features'
Values should be positive ints
"""
if self.y_pred is None:
self.predict()
self.smartapp = SmartApp(self, settings)

def run_app(
Expand Down Expand Up @@ -1046,8 +1069,6 @@ def run_app(

if title_story is not None:
self.title_story = title_story
if self.y_pred is None:
self.predict()
if hasattr(self, "_case"):
self.smartapp = SmartApp(self, settings)
if host is None:
Expand Down
18 changes: 3 additions & 15 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,9 +949,7 @@ def local_pred(self, index, label=None):
float: Predict or predict_proba value
"""
if self.explainer._case == "classification":
if hasattr(self.explainer.model, "predict_proba"):
if not hasattr(self.explainer, "proba_values"):
self.explainer.predict_proba()
if self.explainer.proba_values is not None:
value = self.explainer.proba_values.iloc[:, [label]].loc[index].values[0]
else:
value = None
Expand Down Expand Up @@ -1237,9 +1235,7 @@ def contribution_plot(
col_value = self.explainer._classes[label_num]
subtitle = f"Response: <b>{label_value}</b>"
# predict proba Color scale
if proba and hasattr(self.explainer.model, "predict_proba"):
if not hasattr(self.explainer, "proba_values"):
self.explainer.predict_proba()
if proba and self.explainer.proba_values is not None:
proba_values = self.explainer.proba_values.iloc[:, [label_num]]
if not hasattr(self, "pred_colorscale"):
self.pred_colorscale = {}
Expand Down Expand Up @@ -3209,12 +3205,7 @@ def _prediction_classification_plot(

label_num, _, label_value = self.explainer.check_label_name(label)
# predict proba Color scale
if hasattr(self.explainer.model, "predict_proba"):
if not hasattr(self.explainer, "proba_values"):
self.explainer.predict_proba()
if hasattr(self.explainer.model, "predict"):
if not hasattr(self.explainer, "y_pred") or self.explainer.y_pred is None:
self.explainer.predict()
if self.explainer.proba_values is not None:
# Assign proba values of the target
df_proba_target = self.explainer.proba_values.copy()
df_proba_target["proba_target"] = df_proba_target.iloc[:, label_num]
Expand Down Expand Up @@ -3333,9 +3324,6 @@ def _prediction_regression_plot(
fig = go.Figure()

subtitle = None
if self.explainer.y_pred is None:
if hasattr(self.explainer.model, "predict"):
self.explainer.predict()
prediction_error = self.explainer.prediction_error
if prediction_error is not None:
if (self.explainer.y_target == 0).any()[0]:
Expand Down

0 comments on commit 2be43bd

Please sign in to comment.