def normalize_decimals(text, decimal):
    """Rewrite decimal numbers in `text` so they use a period as separator.

    Every run of digits, followed by `decimal`, followed by another run of
    digits (the whole thing delimited by word boundaries) has its separator
    replaced with '.', so the number can later be converted with float().

    Args:
        text (str): the string to normalize
        decimal (str): the decimal separator used in `text`, e.g. ','

    Returns:
        str: `text` with the separator of each matched number replaced by
            '.'; all other characters are left untouched
    """
    if not decimal or decimal == '.':
        # Nothing to translate. An empty separator would also produce a
        # pattern that glues arbitrary digit pairs together.
        return text

    # re.escape keeps separators that are regex metacharacters ('|', '*',
    # '+', ...) from being interpreted as pattern syntax.
    pattern = r"\b(\d+)" + re.escape(decimal) + r"(\d+)\b"

    # Substitute on the matched spans only, so look-alike substrings that the
    # word-boundary anchors rejected (e.g. 'a1,5') are not rewritten.
    return re.sub(pattern, r"\1.\2", text)
""" + if decimal != '.': + text = normalize_decimals(text, decimal) + lang_code = get_primary_lang_code(lang) if lang_code == "en": return extract_numbers_en(text, short_scale, ordinals) @@ -112,7 +134,8 @@ def extract_numbers(text, short_scale=True, ordinals=False, lang=None): return [] -def extract_number(text, short_scale=True, ordinals=False, lang=None): +def extract_number(text, short_scale=True, ordinals=False, lang=None, + decimal='.'): """Takes in a string and extracts a number. Args: @@ -123,10 +146,17 @@ def extract_number(text, short_scale=True, ordinals=False, lang=None): See https://en.wikipedia.org/wiki/Names_of_large_numbers ordinals (bool): consider ordinal numbers, e.g. third=3 instead of 1/3 lang (str): the BCP-47 code for the language to use, None uses default + decimal (str): character to use as decimal point. defaults to '.' Returns: (int, float or False): The number extracted or False if the input text contains no numbers + Note: + will always extract numbers formatted with a decimal dot/full stop, + such as '3.5', even if 'decimal' is specified. 
""" + if decimal != '.': + text = normalize_decimals(text, decimal) + lang_code = get_primary_lang_code(lang) if lang_code == "en": return extractnumber_en(text, short_scale=short_scale, diff --git a/test/test_parse.py b/test/test_parse.py index 01aec528..afc1fddc 100644 --- a/test/test_parse.py +++ b/test/test_parse.py @@ -123,6 +123,12 @@ def test_extract_number(self): short_scale=False), 1e12) self.assertEqual(extract_number("this is the billionth test", short_scale=False), 1e-12) + + # Test decimal normalization + self.assertEqual(extract_number("4,4", decimal=','), 4.4) + self.assertEqual(extract_number("we have 3,5 kilometers to go", + decimal=','), 3.5) + # TODO handle this case # self.assertEqual( # extract_number("6 dot six six six"), @@ -703,6 +709,11 @@ def test_multiple_numbers(self): self.assertEqual(extract_numbers("this is a seven eight nine and a" " half test"), [7.0, 8.0, 9.5]) + self.assertEqual(extract_numbers("this is a seven eight 9,5 test", + decimal=','), + [7.0, 8.0, 9.5]) + self.assertEqual(extract_numbers("this is a 7,0 8.0 9,6 test", + decimal=','), [7.0, 8.0, 9.6]) def test_contractions(self): self.assertEqual(normalize("ain't"), "is not")