Skip to content

Commit

Permalink
Set proba_values in the method add of the smart_explainer
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Mar 11, 2024
1 parent 974ba0a commit 0dad3f7
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def __init__(
self.features_imp = None

def compile(
self, x, contributions=None, y_pred=None, y_target=None, additional_data=None, additional_features_dict=None
self,
x,
contributions=None,
y_pred=None,
proba_values=None,
y_target=None,
additional_data=None,
additional_features_dict=None,
):
"""
The compile method is the first step to understand model and
Expand All @@ -266,6 +273,11 @@ def compile(
This is an interesting parameter for more explicit outputs.
Shapash lets users define their own predict,
as they may wish to set their own threshold (classification)
proba_values : pandas.Series or pandas.DataFrame, optional (default: None)
Probability values (1 column only).
The index must be identical to the index of x_init.
This is an interesting parameter for more explicit outputs.
Shapash lets users define their own probability values
y_target : pandas.Series or pandas.DataFrame, optional (default: None)
Target values (1 column only).
The index must be identical to the index of x_init.
Expand All @@ -291,14 +303,12 @@ def compile(
x_init = inverse_transform(self.x_encoded, self.preprocessing)
self.x_init = handle_categorical_missing(x_init)
self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
if not hasattr(self, "y_pred") or self.y_pred is None:
if hasattr(self.model, "predict"):
self.predict()
if self._case == "classification":
if hasattr(self.model, "predict_proba"):
self.predict_proba()
else:
self.proba_values = None
if (self.y_pred is None) and (hasattr(self.model, "predict")):
self.predict()

self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
if (self._case == "classification") and (self.proba_values is None) and (hasattr(self.model, "predict_proba")):
self.predict_proba()

self.y_target = check_y(self.x_init, y_target, y_name="y_target")
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
Expand Down Expand Up @@ -414,6 +424,7 @@ def define_style(self, palette_name=None, colors_dict=None):
def add(
self,
y_pred=None,
proba_values=None,
y_target=None,
label_dict=None,
features_dict=None,
Expand All @@ -432,6 +443,9 @@ def add(
y_pred : pandas.Series, optional (default: None)
Prediction values (1 column only).
The index must be identical to the index of x_init.
proba_values : pandas.Series, optional (default: None)
Probability values (1 column only).
The index must be identical to the index of x_init.
label_dict: dict, optional (default: None)
Dictionary mapping integer labels to domain names.
features_dict: dict, optional (default: None)
Expand All @@ -455,6 +469,8 @@ def add(
self.y_pred = check_y(self.x_init, y_pred, y_name="y_pred")
if hasattr(self, "y_target"):
self.prediction_error = predict_error(self.y_target, self.y_pred, self._case)
if proba_values is not None:
self.proba_values = check_y(self.x_init, proba_values, y_name="proba_values")
if y_target is not None:
self.y_target = check_y(self.x_init, y_target, y_name="y_target")
if hasattr(self, "y_pred"):
Expand Down Expand Up @@ -904,6 +920,7 @@ def to_pandas(
)
# Matching with y_pred
if proba:
self.predict_proba()
proba_values = self.proba_values
else:
proba_values = None
Expand Down

0 comments on commit 0dad3f7

Please sign in to comment.