diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index a12abe4a..aeb799f3 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -17,6 +17,7 @@ from shapash.backend import ShapBackend from shapash.explainer.multi_decorator import MultiDecorator from shapash.explainer.smart_state import SmartState +from shapash.plots.plot_bar_chart import plot_bar_chart from shapash.plots.plot_feature_importance import _plot_features_import from shapash.plots.plot_line_comparison import plot_line_comparison from shapash.style.style_utils import get_palette @@ -661,7 +662,7 @@ def test_plot_bar_chart_1(self): ) expected_output_fig = go.Figure(data=bars, layout=go.Layout(yaxis=dict(type="category"))) self.smart_explainer._case = "regression" - fig_output = self.smart_explainer.plot._plot_bar_chart("ind", var_dict, x_val, contributions) + fig_output = plot_bar_chart("ind", var_dict, x_val, contributions) for part in list(zip(fig_output.data, expected_output_fig.data)): assert part[0].x == part[1].x assert part[0].y == part[1].y @@ -684,7 +685,7 @@ def test_plot_bar_chart_2(self): expected_output_fig = go.Figure(data=bars, layout=go.Layout(yaxis=dict(type="category"))) self.smart_explainer._case = "regression" - fig_output = self.smart_explainer.plot._plot_bar_chart("ind", var_dict, x_val, contributions) + fig_output = plot_bar_chart("ind", var_dict, x_val, contributions) for part in list(zip(fig_output.data, expected_output_fig.data)): assert part[0].x == part[1].x assert part[0].y == part[1].y