From f45cd39903a8c0a009a4e7a3eae12ffa80233b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Mon, 13 Nov 2023 10:50:46 +0100 Subject: [PATCH] feat: add bounding box info to IngredientPredictionAggregatedEntity --- robotoff/app/api.py | 12 +----------- .../prediction/ingredient_list/__init__.py | 18 +++++++++++++++--- .../workers/tasks/test_import_image.py | 2 ++ 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 28a53c93db..3bf59d3bbd 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -647,17 +647,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): ], model_version=model_version, ) - - output_dict = dataclasses.asdict(output) - - if aggregation_strategy != "NONE": - # Add bounding boxes to entities - for entity in output_dict["entities"]: - entity["bounding_boxes"] = ocr_result.get_match_bounding_box( - entity["start"], entity["end"] - ) - - resp.media = output_dict + resp.media = dataclasses.asdict(output) class UpdateDatasetResource: diff --git a/robotoff/prediction/ingredient_list/__init__.py b/robotoff/prediction/ingredient_list/__init__.py index 0eb55789a7..ab5c1edcda 100644 --- a/robotoff/prediction/ingredient_list/__init__.py +++ b/robotoff/prediction/ingredient_list/__init__.py @@ -1,7 +1,7 @@ import dataclasses import functools from pathlib import Path -from typing import Optional, Union +from typing import Union import numpy as np from openfoodfacts.ocr import OCRResult @@ -39,7 +39,10 @@ class IngredientPredictionAggregatedEntity: # entity text (without organic or allergen mentions) text: str # language prediction of the entity text - lang: Optional[LanguagePrediction] = None + lang: LanguagePrediction | None = None + # the bounding box of the entity in absolute coordinates + # (y_min, x_min, y_max, x_max), or None if not available + bounding_box: tuple[int, int, int, int] | None = None @dataclasses.dataclass @@ -102,7 +105,16 @@ def predict_from_ocr( predictions = predict_batch( [text], aggregation_strategy, predict_lang, model_version ) - return predictions[0] + prediction = predictions[0] + + for entity in prediction.entities: + if isinstance(entity, IngredientPredictionAggregatedEntity): + # Add the bounding box to the entity + entity.bounding_box = ocr_result.get_match_bounding_box( + entity.start, entity.end + ) + + return prediction @functools.cache diff --git a/tests/integration/workers/tasks/test_import_image.py b/tests/integration/workers/tasks/test_import_image.py index 5fa3b93075..45e896254b 100644 --- a/tests/integration/workers/tasks/test_import_image.py +++ b/tests/integration/workers/tasks/test_import_image.py @@ -33,6 +33,7 @@ def test_extract_ingredients_job(mocker, peewee_db): score=0.9, text="water, salt, sugar.", lang=LanguagePrediction(lang="en", confidence=0.9), + bounding_box=(0, 0, 100, 100), ) ] parsed_ingredients = [ @@ -112,6 +113,7 @@ def test_extract_ingredients_job(mocker, peewee_db): {"in_taxonomy": True, **ingredient} for ingredient in parsed_ingredients ], + "bounding_box": [0, 0, 100, 100], } ], }