From 93b1fcd206c888897cacf6262a183adceed7a9e1 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Thu, 10 Oct 2024 11:36:05 +0200 Subject: [PATCH] rollbach on str type --- shapash/utils/utils.py | 12 ++++++------ tests/unit_tests/explainer/test_smart_plotter.py | 6 +++--- tests/unit_tests/utils/test_utils.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/shapash/utils/utils.py b/shapash/utils/utils.py index a537b78d..b103b3d7 100644 --- a/shapash/utils/utils.py +++ b/shapash/utils/utils.py @@ -73,14 +73,14 @@ def is_nested_list(object_param): return any(isinstance(elem, list) for elem in object_param) -def add_line_break(text, nbchar, maxlen=150): +def add_line_break(value, nbchar, maxlen=150): """ adding line break in string if necessary Parameters ---------- - text : string - string to check in order to add line break + value : string or oither type + if string to check in order to add line break nbchar : int number of characters before line break maxlen : int @@ -91,10 +91,10 @@ def add_line_break(text, nbchar, maxlen=150): string original text + line break """ - if isinstance(text, str): + if isinstance(value, str): length = 0 tot_length = 0 - input_word = text.split() + input_word = value.split() final_sep = [] for w in input_word[:-1]: length = length + len(w) @@ -113,7 +113,7 @@ def add_line_break(text, nbchar, maxlen=150): new_string = "".join(sum(zip(input_word, final_sep + [""]), ())[:-1]) + last_char return new_string else: - return str(text) + return value def truncate_str(text, maxlen=40): diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index b3144e8c..0b66d2fb 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -1860,11 +1860,11 @@ def test_interactions_plot_3(self): output = smart_explainer.plot.interactions_plot(col2, col1, violin_maxf=0) - assert np.array_equal(output.data[0].x, ["34.0"]) + assert np.array_equal(output.data[0].x, [34.0]) assert np.array_equal(output.data[0].y, [-1.4]) assert output.data[0].name == "PhD" - assert np.array_equal(output.data[1].x, ["27.0"]) + assert np.array_equal(output.data[1].x, [27.0]) assert np.array_equal(output.data[1].y, [-0.2]) assert output.data[1].name == "Master" @@ -1893,7 +1893,7 @@ def test_interactions_plot_4(self): output = smart_explainer.plot.interactions_plot(col1, col2, violin_maxf=0) - assert np.array_equal(output.data[0].x, ["520.0", "12800.0"]) + assert np.array_equal(output.data[0].x, [520, 12800]) assert np.array_equal(output.data[0].y, [-1.4, -0.2]) assert np.array_equal(output.data[0].marker.color, [34.0, 27.0]) diff --git a/tests/unit_tests/utils/test_utils.py b/tests/unit_tests/utils/test_utils.py index 0448f4dd..91e5959f 100644 --- a/tests/unit_tests/utils/test_utils.py +++ b/tests/unit_tests/utils/test_utils.py @@ -83,7 +83,7 @@ def test_truncate_str_3(self): def test_add_line_break_1(self): t = add_line_break(3453, 10) - assert t == "3453" + assert t == 3453 def test_add_line_break_2(self): t = add_line_break("this is a very long sentence in order to make a very great test", 10)