Skip to content

Commit

Permalink
Merge pull request #589 from guillaume-vignal/feature/smartplotter_si…
Browse files Browse the repository at this point in the history
…mplification

Refactor DataFrame Column Transformation to Avoid Future Warning
  • Loading branch information
guillaume-vignal authored Oct 10, 2024
2 parents 4660ec7 + 93b1fcd commit 3c8a621
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
20 changes: 10 additions & 10 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ def contribution_plot(
--------
>>> xpl.plot.contribution_plot(0)
"""

if self._explainer._case == "classification":
label_num, _, label_value = self._explainer.check_label_name(label)

Expand Down Expand Up @@ -505,8 +504,13 @@ def contribution_plot(
else:
feature_values = self._explainer.x_init.loc[list_ind, col_name]

if self.explainer.x_init[col_name].dtype == 'bool':
feature_values = feature_values.astype(int)
if isinstance(col_name, list):
for el in col_name:
if feature_values[el].dtype == "bool":
feature_values[el] = feature_values[el].astype(int)
else:
if feature_values.dtype == "bool":
feature_values = feature_values.astype(int)

if col_is_group:
feature_values = project_feature_values_1d(
Expand Down Expand Up @@ -1131,13 +1135,9 @@ def interactions_plot(

# add break line to X label if necessary
max_len_by_row = max([round(50 / self._explainer.features_desc[feature_values1.columns.values[0]]), 8])
feature_values1.iloc[:, 0] = feature_values1.iloc[:, 0].apply(
add_line_break,
args=(
max_len_by_row,
120,
),
)
args = (max_len_by_row, 120)
feature_values_str = feature_values1.iloc[:, 0].apply(add_line_break, args=args)
feature_values1 = pd.DataFrame({feature_values1.columns[0]: feature_values_str})

# selecting the best plot : Scatter, Violin?
if col_value_count1 > violin_maxf:
Expand Down
6 changes: 4 additions & 2 deletions shapash/plots/plot_contribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def plot_scatter(

# add break line to X label if necessary
args = (max_len_by_row, 120)
feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values_str = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values = pd.DataFrame({column_name: feature_values_str})

if pred is not None:
hv_text = [f"Id: {x}<br />Predict: {y}" for x, y in zip(feature_values.index, pred.values.flatten())]
Expand Down Expand Up @@ -267,7 +268,8 @@ def plot_violin(

# add break line to X label if necessary
args = (max_len_by_row, 120)
feature_values.iloc[:, 0] = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values_str = feature_values.iloc[:, 0].apply(add_line_break, args=args)
feature_values = pd.DataFrame({column_name: feature_values_str})

contributions = contributions.loc[feature_values.index]
if pred is not None:
Expand Down
12 changes: 6 additions & 6 deletions shapash/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def is_nested_list(object_param):
return any(isinstance(elem, list) for elem in object_param)


def add_line_break(text, nbchar, maxlen=150):
def add_line_break(value, nbchar, maxlen=150):
"""
adding line break in string if necessary
Parameters
----------
text : string
string to check in order to add line break
value : string or oither type
if string to check in order to add line break
nbchar : int
number of characters before line break
maxlen : int
Expand All @@ -91,10 +91,10 @@ def add_line_break(text, nbchar, maxlen=150):
string
original text + line break
"""
if isinstance(text, str):
if isinstance(value, str):
length = 0
tot_length = 0
input_word = text.split()
input_word = value.split()
final_sep = []
for w in input_word[:-1]:
length = length + len(w)
Expand All @@ -113,7 +113,7 @@ def add_line_break(text, nbchar, maxlen=150):
new_string = "".join(sum(zip(input_word, final_sep + [""]), ())[:-1]) + last_char
return new_string
else:
return text
return value


def truncate_str(text, maxlen=40):
Expand Down

0 comments on commit 3c8a621

Please sign in to comment.