From 6b62ca221db87c2901c300ca7add634b39fa9a31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Fri, 10 Nov 2023 12:28:19 +0100 Subject: [PATCH] fix: improve ingredient detection output saved in DB - add the `in_taxonomy` tag recursively (each subingredient) - use all subingredients to compute ingredients_n, known_ingredients_n and unknown_ingredients_n - fix bug when using max with empty list as input --- robotoff/workers/tasks/import_image.py | 132 +++++++++++++----- .../workers/tasks/test_import_image.py | 1 - tests/unit/workers/tasks/test_import_image.py | 118 +++++++++++++++- 3 files changed, 212 insertions(+), 39 deletions(-) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index f570e35b0a..fdd7fa04c0 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -6,6 +6,7 @@ import elasticsearch from elasticsearch.helpers import BulkIndexError from openfoodfacts import OCRResult +from openfoodfacts.taxonomy import Taxonomy from openfoodfacts.types import TaxonomyType from PIL import Image @@ -626,48 +627,105 @@ def extract_ingredients_job(product_id: ProductIdentifier, ocr_url: str): ] = output.entities # type: ignore # (we know it's an aggregated entity, so we can ignore the type) - image_prediction_data = dataclasses.asdict(output) - ingredient_taxonomy = get_taxonomy(TaxonomyType.ingredient) - - for entity in image_prediction_data["entities"]: - # This is just an extra check, we should have lang information - # available - if entity["lang"]: - lang_id = entity["lang"]["lang"] - try: - # Parse ingredients using Product Opener ingredient parser, - # and add it to the entity data - parsed_ingredients = parse_ingredients(entity["text"], lang_id) - except RuntimeError as e: - logger.info( - "Error while parsing ingredients, skipping " - "to the next ingredient list", - exc_info=e, - ) - continue - - known_ingredients_n = 0 - ingredients_n = len(parsed_ingredients) - for ingredient_data in parsed_ingredients: - ingredient_id = ingredient_data["id"] - ingredient_data["in_taxonomy"] = ( - ingredient_id in ingredient_taxonomy - ) - known_ingredients_n += int(ingredient_data["in_taxonomy"]) - - # We use the same terminology as Product Opener - entity["ingredients_n"] = ingredients_n - entity["known_ingredients_n"] = known_ingredients_n - entity["unknown_ingredients_n"] = ingredients_n - known_ingredients_n - entity["ingredients"] = parsed_ingredients - + ingredient_prediction_data = generate_ingredient_prediction_data(output) ImagePrediction.create( image=image_model, type="ner", model_name=ingredient_list.MODEL_NAME, model_version=ingredient_list.MODEL_VERSION, - data=image_prediction_data, + data=ingredient_prediction_data, timestamp=datetime.datetime.utcnow(), - max_confidence=max(entity.score for entity in entities), + max_confidence=max(entity.score for entity in entities) + if entities + else None, ) logger.info("create image prediction (ingredient detection) from %s", ocr_url) + + +def generate_ingredient_prediction_data( + ingredient_prediction_output: ingredient_list.IngredientPredictionOutput, +) -> JSONType: + """Generate a JSON-like object from the ingredient prediction output to + be saved in ImagePrediction data field. + + We remove the full text, as it's usually very long, and add a few + additional fields: + + - `ingredients_n`: the total number of ingredients + - `known_ingredients_n`: the number of known ingredients + - `unknown_ingredients_n`: the number of unknown ingredients + - `ingredients`: the parsed ingredients, in Product Opener format (with + the `in_taxonomy` field added) + + :param ingredient_prediction_output: the ingredient prediction output + :return: the generated JSON-like object + """ + ingredient_prediction_data = dataclasses.asdict(ingredient_prediction_output) + # Remove the full text, as it's usually very long + ingredient_prediction_data.pop("text") + ingredient_taxonomy = get_taxonomy(TaxonomyType.ingredient) + + for entity in ingredient_prediction_data["entities"]: + # This is just an extra check, we should have lang information + # available + if entity["lang"]: + lang_id = entity["lang"]["lang"] + try: + # Parse ingredients using Product Opener ingredient parser, + # and add it to the entity data + parsed_ingredients = parse_ingredients(entity["text"], lang_id) + except RuntimeError as e: + logger.info( + "Error while parsing ingredients, skipping " + "to the next ingredient list", + exc_info=e, + ) + continue + + ingredients_n, known_ingredients_n = add_ingredient_in_taxonomy_field( + parsed_ingredients, ingredient_taxonomy + ) + + # We use the same terminology as Product Opener + entity["ingredients_n"] = ingredients_n + entity["known_ingredients_n"] = known_ingredients_n + entity["unknown_ingredients_n"] = ingredients_n - known_ingredients_n + entity["ingredients"] = parsed_ingredients + + return ingredient_prediction_data + + +def add_ingredient_in_taxonomy_field( + parsed_ingredients: list[JSONType], ingredient_taxonomy: Taxonomy +) -> tuple[int, int]: + """Add the `in_taxonomy` field to each ingredient in `parsed_ingredients`. + + This function is called recursively to add the `in_taxonomy` field to each + sub-ingredient. It returns the total number of ingredients and the number + of known ingredients (including sub-ingredients). + + :param parsed_ingredients: a list of parsed ingredients, in Product Opener + format + :param ingredient_taxonomy: the ingredient taxonomy + :return: a (total_ingredients_n, known_ingredients_n) tuple + """ + ingredients_n = 0 + known_ingredients_n = 0 + for ingredient_data in parsed_ingredients: + ingredient_id = ingredient_data["id"] + in_taxonomy = ingredient_id in ingredient_taxonomy + ingredient_data["in_taxonomy"] = in_taxonomy + known_ingredients_n += int(in_taxonomy) + ingredients_n += 1 + + if "ingredients" in ingredient_data: + ( + sub_ingredients_n, + known_sub_ingredients_n, + ) = add_ingredient_in_taxonomy_field( + ingredient_data["ingredients"], ingredient_taxonomy + ) + ingredients_n += sub_ingredients_n + known_ingredients_n += known_sub_ingredients_n + + return ingredients_n, known_ingredients_n diff --git a/tests/integration/workers/tasks/test_import_image.py b/tests/integration/workers/tasks/test_import_image.py index aa98a0f7ed..5fa3b93075 100644 --- a/tests/integration/workers/tasks/test_import_image.py +++ b/tests/integration/workers/tasks/test_import_image.py @@ -97,7 +97,6 @@ def test_extract_ingredients_job(mocker, peewee_db): ) assert image_prediction is not None assert image_prediction.data == { - "text": full_text, "entities": [ { "end": 51, diff --git a/tests/unit/workers/tasks/test_import_image.py b/tests/unit/workers/tasks/test_import_image.py index eedeb476c7..775634b1dc 100644 --- a/tests/unit/workers/tasks/test_import_image.py +++ b/tests/unit/workers/tasks/test_import_image.py @@ -1,6 +1,11 @@ import pytest +from openfoodfacts.types import TaxonomyType -from robotoff.workers.tasks.import_image import get_text_from_bounding_box +from robotoff.taxonomy import get_taxonomy +from robotoff.workers.tasks.import_image import ( + add_ingredient_in_taxonomy_field, + get_text_from_bounding_box, +) from ...pytest_utils import get_ocr_result_asset @@ -41,3 +46,114 @@ def test_get_text_from_bounding_box( ocr_result, bounding_box, image_width, image_height ) assert text == expected_text + + +def test_add_ingredient_in_taxonomy_field(): + parsed_ingredients = [ + { + "id": "en:water", + "text": "water", + "percent_min": 33.3333333333333, + "percent_max": 100, + "percent_estimate": 66.6666666666667, + "vegan": "yes", + "vegetarian": "yes", + }, + { + "id": "en:salt", + "text": "salt", + "percent_min": 0, + "percent_max": 50, + "percent_estimate": 16.6666666666667, + "vegan": "yes", + "vegetarian": "yes", + }, + { + "id": "en:sugar", + "text": "sugar", + "percent_min": 0, + "percent_max": 33.3333333333333, + "percent_estimate": 16.6666666666667, + "vegan": "yes", + "vegetarian": "yes", + "ingredients": [ + { + "id": "en:glucose", + "text": "glucose", + "percent_min": 0, + "percent_max": 100, + "percent_estimate": 100, + "vegan": "yes", + "vegetarian": "yes", + }, + { + "id": "en:unknown-ingredient", + "text": "Unknown ingredient", + "percent_min": 0, + "percent_max": 100, + "percent_estimate": 100, + }, + ], + }, + ] + ingredient_taxonomy = get_taxonomy(TaxonomyType.ingredient, offline=True) + + total_ingredients_n, known_ingredients_n = add_ingredient_in_taxonomy_field( + parsed_ingredients, ingredient_taxonomy + ) + + assert total_ingredients_n == 5 + assert known_ingredients_n == 4 + + assert parsed_ingredients == [ + { + "id": "en:water", + "text": "water", + "percent_min": 33.3333333333333, + "percent_max": 100, + "percent_estimate": 66.6666666666667, + "vegan": "yes", + "vegetarian": "yes", + "in_taxonomy": True, + }, + { + "id": "en:salt", + "text": "salt", + "percent_min": 0, + "percent_max": 50, + "percent_estimate": 16.6666666666667, + "vegan": "yes", + "vegetarian": "yes", + "in_taxonomy": True, + }, + { + "id": "en:sugar", + "text": "sugar", + "percent_min": 0, + "percent_max": 33.3333333333333, + "percent_estimate": 16.6666666666667, + "vegan": "yes", + "vegetarian": "yes", + "in_taxonomy": True, + "ingredients": [ + { + "id": "en:glucose", + "text": "glucose", + "percent_min": 0, + "percent_max": 100, + "percent_estimate": 100, + "vegan": "yes", + "vegetarian": "yes", + "in_taxonomy": True, + }, + { + "id": "en:unknown-ingredient", + "text": "Unknown ingredient", + "percent_min": 0, + "percent_max": 100, + "percent_estimate": 100, + "in_taxonomy": False, + }, + ], + }, + ]