Skip to content

Commit

Permalink
fix test on _plot_features_import
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Oct 4, 2024
1 parent cd851bd commit f0f5aee
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/unit_tests/explainer/test_smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from shapash.backend import ShapBackend
from shapash.explainer.multi_decorator import MultiDecorator
from shapash.explainer.smart_state import SmartState
from shapash.plots.plot_feature_importance import _plot_features_import
from shapash.plots.plot_line_comparison import plot_line_comparison
from shapash.style.style_utils import get_palette
from shapash.utils.check import check_model
Expand Down Expand Up @@ -1126,8 +1127,9 @@ def test_plot_features_import_1(self):
"""
Unit test plot features import 1
"""
xpl = self.smart_explainer
serie1 = pd.Series([0.131, 0.51], index=["col1", "col2"])
output = self.smart_explainer.plot._plot_features_import(serie1)
output = _plot_features_import(serie1, xpl.plot._style_dict, {})
data = go.Bar(x=serie1, y=serie1.index, name="Global", orientation="h")

expected_output = go.Figure(data=data)
Expand All @@ -1140,9 +1142,10 @@ def test_plot_features_import_2(self):
"""
Unit test plot features import 2
"""
xpl = self.smart_explainer
serie1 = pd.Series([0.131, 0.51], index=["col1", "col2"])
serie2 = pd.Series([0.33, 0.11], index=["col1", "col2"])
output = self.smart_explainer.plot._plot_features_import(serie1, serie2)
output = _plot_features_import(serie1, xpl.plot._style_dict, {}, feature_imp2=serie2)
data1 = go.Bar(x=serie1, y=serie1.index, name="Global", orientation="h")
data2 = go.Bar(x=serie2, y=serie2.index, name="Subset", orientation="h")
expected_output = go.Figure(data=[data2, data1])
Expand Down

0 comments on commit f0f5aee

Please sign in to comment.