diff --git a/fast_autocomplete/normalize.py b/fast_autocomplete/normalize.py index 41fd424..0bd3e67 100644 --- a/fast_autocomplete/normalize.py +++ b/fast_autocomplete/normalize.py @@ -12,11 +12,14 @@ def normalize_node_name(name, extra_chars=None): + if name is None: + return '' name = name[:MAX_WORD_LENGTH] - result = _normalized_lfu_cache.get(name) + key = name if extra_chars is None else f"{name}{extra_chars}" + result = _normalized_lfu_cache.get(key) if result == -1: result = _get_normalized_node_name(name, extra_chars=extra_chars) - _normalized_lfu_cache.set(name, result) + _normalized_lfu_cache.set(key, result) return result @@ -32,6 +35,8 @@ def remove_any_special_character(name): """ Only remove invalid characters from a name. Useful for cleaning the user's original word. """ + if name is None: + return '' name = name.lower()[:MAX_WORD_LENGTH] _remove_invalid_chars.prev_x = '' @@ -43,7 +48,7 @@ def _get_normalized_node_name(name, extra_chars=None): result = [] last_i = None for i in name: - if i in valid_chars_for_node_name or extra_chars and i in extra_chars: + if i in valid_chars_for_node_name or (extra_chars and i in extra_chars): if i == '-': i = ' ' elif (i in valid_chars_for_integer and last_i in valid_chars_for_string) or (i in valid_chars_for_string and last_i in valid_chars_for_integer): diff --git a/tests/test_normalize.py b/tests/test_normalize.py index 6f84f7c..60aa51b 100644 --- a/tests/test_normalize.py +++ b/tests/test_normalize.py @@ -1,5 +1,5 @@ import pytest -from fast_autocomplete.normalize import remove_any_special_character +from fast_autocomplete.normalize import remove_any_special_character, normalize_node_name class TestMisc: @@ -9,7 +9,20 @@ class TestMisc: ('HONDA and Toyota!', 'honda and toyota'), (r'bmw? \#1', 'bmw 1'), (r'bmw? \#', 'bmw'), + (None, ''), ]) - def test_extend_and_repeat(self, name, expected_result): + def test_remove_any_special_character(self, name, expected_result): result = remove_any_special_character(name) assert expected_result == result + + @pytest.mark.parametrize("name, extra_chars, expected_result", [ + ('type-r', None, 'type r'), + ('HONDA and Toyota!', None, 'honda and toyota'), + (r'bmw? \#1', None, 'bmw 1'), + (r'bmw? \#', None, 'bmw'), + (r'bmw? \#', {'#'}, 'bmw #'), + (None, None, ''), + ]) + def test_normalize_node_name(self, name, extra_chars, expected_result): + result = normalize_node_name(name, extra_chars=extra_chars) + assert expected_result == result