Skip to content

Commit

Permalink
Merge pull request #582 from guillaume-vignal/feature/smartplotter_si…
Browse files Browse the repository at this point in the history
…mplification

SmartPlotter simplification by delegating each plot type to a separate function file
  • Loading branch information
guillaume-vignal authored Oct 4, 2024
2 parents d03dca1 + e972df3 commit 8a50775
Show file tree
Hide file tree
Showing 17 changed files with 3,661 additions and 2,981 deletions.
29 changes: 29 additions & 0 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def compile(
else self._compile_additional_features_dict(additional_features_dict)
)
self.additional_data = self._compile_additional_data(additional_data)
self.plot._tuning_round_digit()

def _get_contributions_from_backend_or_user(self, x, contributions):
# Computing contributions using backend
Expand Down Expand Up @@ -1320,3 +1321,31 @@ def generate_report(
if rm_working_dir:
shutil.rmtree(working_dir)
raise e

def _local_pred(self, index, label=None):
"""
compute a local pred to display in local_plot
Parameters
----------
index: string, int, float, ...
specify the row we want to pred
label: int (default: None)
Returns
-------
float: Predict or predict_proba value
"""
if self._case == "classification":
if self.proba_values is not None:
value = self.proba_values.iloc[:, [label]].loc[index].values[0]
else:
value = None
elif self._case == "regression":
if self.y_pred is not None:
value = self.y_pred.loc[index]
else:
value = self.model.predict(self.x_encoded.loc[[index]])[0]

if isinstance(value, pd.Series):
value = value.values[0]

return value
Loading

0 comments on commit 8a50775

Please sign in to comment.