From 80d3a09af24447570a7c09362e6c416194add671 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Thu, 31 Aug 2023 15:06:26 +0200 Subject: [PATCH] fix: update normalize_weight function after Pint upgrade --- robotoff/prediction/ocr/product_weight.py | 9 ++++++++- tests/unit/prediction/ocr/test_product_weight.py | 7 +++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/robotoff/prediction/ocr/product_weight.py b/robotoff/prediction/ocr/product_weight.py index 8df0c6873c..e8257851f3 100644 --- a/robotoff/prediction/ocr/product_weight.py +++ b/robotoff/prediction/ocr/product_weight.py @@ -1,4 +1,5 @@ import functools +import math import re from typing import Optional, Union @@ -42,7 +43,13 @@ def normalize_weight(value: str, unit: str) -> tuple[float, str]: else: raise ValueError(f"unknown unit: {quantity.u}") - return normalized_quantity.magnitude, normalized_unit + # Rounding errors due to float may occur with Pint, + # round normalized value to floor if there is no significant difference + normalized_value = normalized_quantity.magnitude + if math.isclose(math.floor(normalized_value), normalized_value): + normalized_value = math.floor(normalized_value) + + return normalized_value, normalized_unit def is_valid_weight(weight_value: str) -> bool: diff --git a/tests/unit/prediction/ocr/test_product_weight.py b/tests/unit/prediction/ocr/test_product_weight.py index 0e3fa256ce..87287afa17 100644 --- a/tests/unit/prediction/ocr/test_product_weight.py +++ b/tests/unit/prediction/ocr/test_product_weight.py @@ -1,3 +1,5 @@ +import math + import pytest from robotoff.prediction.ocr.dataclass import OCRRegex @@ -66,8 +68,9 @@ def test_product_weight_with_ending_mention_regex(input_str: str, is_match: bool ], ) def test_normalize_weight(value: str, unit: str, expected: tuple[float, str]): - result = normalize_weight(value, unit) - assert result == expected + normalized_value, normalized_unit = normalize_weight(value, unit) + assert math.isclose(normalized_value, expected[0]) + assert normalized_unit == expected[1] @pytest.mark.parametrize(